From 6c89d8d35cfc3fabd38fdfc8463203ee42ed1c2d Mon Sep 17 00:00:00 2001
From: fjl5
Date: Mon, 13 Apr 2026 17:30:49 +0800
Subject: [PATCH 001/261] =?UTF-8?q?add=20prompt=5Fcache=5Fkey=20injection?=
=?UTF-8?q?=20for=20messages=E2=86=92responses?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../service/openai_gateway_messages.go | 22 +++++++++++++++++++
1 file changed, 22 insertions(+)
diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go
index a72b9bbf..2a0a72eb 100644
--- a/backend/internal/service/openai_gateway_messages.go
+++ b/backend/internal/service/openai_gateway_messages.go
@@ -121,6 +121,28 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
+ // For API key accounts (including OpenAI-compatible upstream gateways),
+ // ensure promptCacheKey is also propagated via the request body so that
+ // upstreams using the Responses API can derive a stable session identifier
+ // from prompt_cache_key. This makes our Anthropic /v1/messages compatibility
+ // path behave more like a native Responses client.
+ if account.Type == AccountTypeAPIKey {
+ if trimmedKey := strings.TrimSpace(promptCacheKey); trimmedKey != "" {
+ var reqBody map[string]any
+ if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
+ return nil, fmt.Errorf("unmarshal for prompt cache key injection: %w", err)
+ }
+ if existing, ok := reqBody["prompt_cache_key"].(string); !ok || strings.TrimSpace(existing) == "" {
+ reqBody["prompt_cache_key"] = trimmedKey
+ updated, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("remarshal after prompt cache key injection: %w", err)
+ }
+ responsesBody = updated
+ }
+ }
+ }
+
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
--
GitLab
From 10699eeb34e0e9b602a50d43a1be434d4751d5af Mon Sep 17 00:00:00 2001
From: erio
Date: Thu, 16 Apr 2026 01:53:22 +0800
Subject: [PATCH 002/261] refactor: extract ReadUpstreamResponseBody to
deduplicate upstream response read + too-large error handling
Consolidates 9 call sites of resolveUpstreamResponseReadLimit + readUpstreamResponseBodyLimited + ErrUpstreamResponseBodyTooLarge error handling into a single ReadUpstreamResponseBody function with a TooLargeWriter callback for API-format-specific error responses (Anthropic, OpenAI, countTokens).
---
backend/internal/service/gateway_service.go | 74 +++++--------------
.../service/gemini_messages_compat_service.go | 12 +--
.../service/openai_gateway_service.go | 24 +-----
.../service/upstream_response_limit.go | 43 +++++++++++
.../service/upstream_response_limit_test.go | 43 +++++++++++
5 files changed, 107 insertions(+), 89 deletions(-)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index c65e828a..4b4fc0bf 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -5120,19 +5120,8 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
}
- maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
- body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- c.JSON(http.StatusBadGateway, gin.H{
- "type": "error",
- "error": gin.H{
- "type": "upstream_error",
- "message": "Upstream response too large",
- },
- })
- }
return nil, err
}
@@ -5498,19 +5487,8 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
c *gin.Context,
account *Account,
) (*ClaudeUsage, error) {
- maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
- body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- c.JSON(http.StatusBadGateway, gin.H{
- "type": "error",
- "error": gin.H{
- "type": "upstream_error",
- "message": "Upstream response too large",
- },
- })
- }
return nil, err
}
@@ -7175,19 +7153,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
- maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
- body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- c.JSON(http.StatusBadGateway, gin.H{
- "type": "error",
- "error": gin.H{
- "type": "upstream_error",
- "message": "Upstream response too large",
- },
- })
- }
return nil, err
}
@@ -8300,16 +8267,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 读取响应体
- maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
- respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
+ countTokensTooLarge := func(c *gin.Context) {
+ s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
+ }
+ respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
_ = resp.Body.Close()
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
- return err
+ if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
+ s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
}
- s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
@@ -8323,15 +8289,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if retryErr == nil {
resp = retryResp
- respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
+ respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
_ = resp.Body.Close()
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
- return err
+ if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
+ s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
}
- s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
}
@@ -8426,16 +8389,15 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
return fmt.Errorf("upstream request failed: %w", err)
}
- maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
- respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
+ countTokensTooLarge := func(c *gin.Context) {
+ s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
+ }
+ respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
_ = resp.Body.Close()
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
- return err
+ if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
+ s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
}
- s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 5a9490f3..7a24071b 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -2424,18 +2424,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
}
- maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
- respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
+ respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- c.JSON(http.StatusBadGateway, gin.H{
- "error": gin.H{
- "type": "upstream_error",
- "message": "Upstream response too large",
- },
- })
- }
return nil, err
}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index ef97daad..064191bd 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -3010,18 +3010,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
resp *http.Response,
c *gin.Context,
) (*OpenAIUsage, error) {
- maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
- body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- c.JSON(http.StatusBadGateway, gin.H{
- "error": gin.H{
- "type": "upstream_error",
- "message": "Upstream response too large",
- },
- })
- }
return nil, err
}
@@ -3919,18 +3909,8 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
- maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
- body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
- if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
- setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
- c.JSON(http.StatusBadGateway, gin.H{
- "error": gin.H{
- "type": "upstream_error",
- "message": "Upstream response too large",
- },
- })
- }
return nil, err
}
diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go
index aecf69a3..a0444d52 100644
--- a/backend/internal/service/upstream_response_limit.go
+++ b/backend/internal/service/upstream_response_limit.go
@@ -4,8 +4,10 @@ import (
"errors"
"fmt"
"io"
+ "net/http"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
)
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
@@ -36,3 +38,44 @@ func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte,
}
return body, nil
}
+
+// TooLargeWriter 在响应超限时向客户端写格式化的错误响应。
+type TooLargeWriter func(c *gin.Context)
+
+// ReadUpstreamResponseBody 读取上游非流式响应体。
+// 超限时自动记录 ops error 并调用 onTooLarge 向客户端写错误。
+func ReadUpstreamResponseBody(reader io.Reader, cfg *config.Config, c *gin.Context, onTooLarge TooLargeWriter) ([]byte, error) {
+ maxBytes := resolveUpstreamResponseReadLimit(cfg)
+ body, err := readUpstreamResponseBodyLimited(reader, maxBytes)
+ if err != nil {
+ if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
+ setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
+ if onTooLarge != nil {
+ onTooLarge(c)
+ }
+ }
+ return nil, err
+ }
+ return body, nil
+}
+
+// anthropicTooLargeError 以 Anthropic Messages API 格式写入超限错误。
+func anthropicTooLargeError(c *gin.Context) {
+ c.JSON(http.StatusBadGateway, gin.H{
+ "type": "error",
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": "Upstream response too large",
+ },
+ })
+}
+
+// openAITooLargeError 以 OpenAI / Gemini 格式写入超限错误。
+func openAITooLargeError(c *gin.Context) {
+ c.JSON(http.StatusBadGateway, gin.H{
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": "Upstream response too large",
+ },
+ })
+}
diff --git a/backend/internal/service/upstream_response_limit_test.go b/backend/internal/service/upstream_response_limit_test.go
index b9e5cc6d..09283189 100644
--- a/backend/internal/service/upstream_response_limit_test.go
+++ b/backend/internal/service/upstream_response_limit_test.go
@@ -4,8 +4,10 @@ import (
"bytes"
"errors"
"testing"
+ "testing/iotest"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -35,3 +37,44 @@ func TestReadUpstreamResponseBodyLimited(t *testing.T) {
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
})
}
+
+func TestReadUpstreamResponseBody(t *testing.T) {
+ t.Run("within limit", func(t *testing.T) {
+ body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("ok")), nil, nil, nil)
+ require.NoError(t, err)
+ require.Equal(t, []byte("ok"), body)
+ })
+
+ t.Run("exceeds limit calls onTooLarge", func(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.UpstreamResponseReadMaxBytes = 3
+
+ called := false
+ onTooLarge := func(_ *gin.Context) { called = true }
+
+ body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, onTooLarge)
+ require.Nil(t, body)
+ require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
+ require.True(t, called)
+ })
+
+ t.Run("nil onTooLarge does not panic", func(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.UpstreamResponseReadMaxBytes = 3
+
+ body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, nil)
+ require.Nil(t, body)
+ require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
+ })
+
+ t.Run("io error does not call onTooLarge", func(t *testing.T) {
+ called := false
+ onTooLarge := func(_ *gin.Context) { called = true }
+
+ body, err := ReadUpstreamResponseBody(iotest.ErrReader(errors.New("disk failure")), nil, nil, onTooLarge)
+ require.Nil(t, body)
+ require.Error(t, err)
+ require.False(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
+ require.False(t, called)
+ })
+}
--
GitLab
From 3944b3d216bbcc10472fccdc4f6606fa2b85106e Mon Sep 17 00:00:00 2001
From: KnowSky404
Date: Thu, 16 Apr 2026 02:01:50 +0000
Subject: [PATCH 003/261] fix: preserve openai ws flags in scheduler cache
---
.../internal/repository/scheduler_cache.go | 7 +++
.../repository/scheduler_cache_unit_test.go | 33 ++++++++++
...enai_account_scheduler_ws_snapshot_test.go | 62 +++++++++++++++++++
3 files changed, 102 insertions(+)
create mode 100644 backend/internal/repository/scheduler_cache_unit_test.go
create mode 100644 backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go
index e9be8c7a..add0e501 100644
--- a/backend/internal/repository/scheduler_cache.go
+++ b/backend/internal/repository/scheduler_cache.go
@@ -426,6 +426,13 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
"window_cost_sticky_reserve",
"max_sessions",
"session_idle_timeout_minutes",
+ "openai_oauth_responses_websockets_v2_enabled",
+ "openai_oauth_responses_websockets_v2_mode",
+ "openai_apikey_responses_websockets_v2_enabled",
+ "openai_apikey_responses_websockets_v2_mode",
+ "responses_websockets_v2_enabled",
+ "openai_ws_enabled",
+ "openai_ws_force_http",
}
filtered := make(map[string]any)
for _, key := range keys {
diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go
new file mode 100644
index 00000000..bcfd0e7a
--- /dev/null
+++ b/backend/internal/repository/scheduler_cache_unit_test.go
@@ -0,0 +1,33 @@
+//go:build unit
+
+package repository
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
+ account := service.Account{
+ ID: 42,
+ Platform: service.PlatformOpenAI,
+ Type: service.AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": true,
+ "openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
+ "openai_ws_force_http": true,
+ "mixed_scheduling": true,
+ "unused_large_field": "drop-me",
+ },
+ }
+
+ got := buildSchedulerMetadataAccount(account)
+
+ require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"])
+ require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"])
+ require.Equal(t, true, got.Extra["openai_ws_force_http"])
+ require.Equal(t, true, got.Extra["mixed_scheduling"])
+ require.Nil(t, got.Extra["unused_large_field"])
+}
diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
new file mode 100644
index 00000000..c5de8203
--- /dev/null
+++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
@@ -0,0 +1,62 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapshotFlags(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(10105)
+ account := &Account{
+ ID: 35001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 10,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ },
+ }
+
+ snapshotCache := &openAISnapshotCacheStub{
+ snapshotAccounts: []*Account{account},
+ accountsByID: map[int64]*Account{account.ID: account},
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "session_hash_ws_passthrough",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, account.ID, selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
--
GitLab
From 836092a6665f3f8f50ef8f4b055b89d73c748310 Mon Sep 17 00:00:00 2001
From: KnowSky404
Date: Thu, 16 Apr 2026 02:13:04 +0000
Subject: [PATCH 004/261] fix: restore ctx pool ws mode option in account ui
---
frontend/src/components/account/BulkEditAccountModal.vue | 2 ++
frontend/src/components/account/CreateAccountModal.vue | 5 ++---
frontend/src/components/account/EditAccountModal.vue | 5 ++---
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue
index 5461015b..13c30cf9 100644
--- a/frontend/src/components/account/BulkEditAccountModal.vue
+++ b/frontend/src/components/account/BulkEditAccountModal.vue
@@ -921,6 +921,7 @@ import {
getPresetMappingsByPlatform
} from '@/composables/useModelWhitelist'
import {
+ OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
@@ -1069,6 +1070,7 @@ const isOpenAIModelRestrictionDisabled = computed(
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
+ { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openAIWSModeConcurrencyHintKey = computed(() =>
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 432fa080..2130c9ab 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -2932,7 +2932,7 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
- // OPENAI_WS_MODE_CTX_POOL,
+ OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
@@ -3180,8 +3180,7 @@ const geminiSelectedTier = computed(() => {
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
- // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
- // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
+ { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index e5065d7a..1da32e2c 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1858,7 +1858,7 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
- // OPENAI_WS_MODE_CTX_POOL,
+ OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
OPENAI_WS_MODE_PASSTHROUGH,
isOpenAIWSModeEnabled,
@@ -2020,8 +2020,7 @@ const editWeeklyResetHour = ref(null)
const editResetTimezone = ref(null)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
- // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
- // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
+ { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') },
{ value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') }
])
const openaiResponsesWebSocketV2Mode = computed({
--
GitLab
From a55ead5ea860c2517c625fbbe6cb2f1823382954 Mon Sep 17 00:00:00 2001
From: shaw
Date: Thu, 16 Apr 2026 16:42:40 +0800
Subject: [PATCH 005/261] chore: remove stale Antigravity-Manager submodule reference
---
Antigravity-Manager | 1 -
1 file changed, 1 deletion(-)
delete mode 160000 Antigravity-Manager
diff --git a/Antigravity-Manager b/Antigravity-Manager
deleted file mode 160000
index a9d96bd5..00000000
--- a/Antigravity-Manager
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a9d96bd54978c22d3033830debfe77aeeeee2500
--
GitLab
From 7ea8e7e6675c3e5b4c5c73578136420c76a4ee64 Mon Sep 17 00:00:00 2001
From: shaw
Date: Thu, 16 Apr 2026 17:19:32 +0800
Subject: [PATCH 006/261] chore: update sponsors
---
README.md | 5 +++++
README_CN.md | 5 +++++
README_JA.md | 5 +++++
assets/partners/logos/bmoplus.jpg | Bin 0 -> 7996 bytes
4 files changed, 15 insertions(+)
create mode 100644 assets/partners/logos/bmoplus.jpg
diff --git a/README.md b/README.md
index c2715eae..74ab9af2 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via this link , you'll receive an extra 10% bonus credit on your first top-up!
+
+
+Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)!
+
+
## Ecosystem
diff --git a/README_CN.md b/README_CN.md
index 0ace1f77..c701372c 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -90,6 +90,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过此链接 注册,首次充值可额外获得 10% 赠送额度!
+
+
+感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
+
+
## 生态项目
diff --git a/README_JA.md b/README_JA.md
index d74ca9ce..0d4db616 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -90,6 +90,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:こちらのリンク から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!
+
+
+本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
+
+
## エコシステム
diff --git a/assets/partners/logos/bmoplus.jpg b/assets/partners/logos/bmoplus.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1a9b4d8b7b2d626d860f327a49fe0bf3ade711fb
GIT binary patch
literal 7996
zcmb7o1yr0tvhLs>7-Z1k&fpTg$y>i$ct^Ish300tJyepv#b!@&SxvEZptvqvdea
zGOycfnABnoDwcYd>=!B}IF&h8&UFI3x?@s4$Dit|(QYrc7Ro?`skB=cf3<~i7S`ii
z^iM`JT^=QdYSdk(RAqZd?Nx5kwh8!1?0B4(MY?y7dQB^N8|+&D$WdeC+R9_V
zW8>tSDd|@^{Fy!`Rbuf(SgoD(_Z^1~Dt$VA>L2Gj-ZkZ#GYu>j#WuM|P2SuG*>5Cx
zY>$^bC_k1vK<2nyYpqy0xi0nuoW@YO7*-33FTIoZEtkqI7j+Z(7?x0=X1dOqTTYf<
z3EVZbvpVANE-uG%2)KICOPeCl|^I{&G^K}hwaW8mrCn%x+0al5t$XlHZ^qVoe^
zW^sc3O-X_mN6WSz2d
zRw^`i{u>4V3rPeTfGk)5SXdZ1I5=3uf5agSEC3D&kBx)H!HJ7Ufxt%2E+(#GOi3l`
zpEHVm#<2o>}I*`-rvmeQ64QUhgw^Re6AmZz|Vky_iI2{KVN*cKC3`*I!wrhs@!mD@;gN
zX}Okx5Ytodm_&R0q=^-+xUaSqb}i9xzx3FY?dqmym2Ft@cWxAz7=Oz&u2UEVuxGN=)-7KakC}uIUSfC
z8Knp^3v1Q{fpsEu`o$9pQ{AJ+LD}>JJa5pHlNnu*);LveQPJ3lrAtn1ueeUUib^dY
zwmiC~juKqI9mIa9Zfi30LYR{Lw%@q8e_l(uEx9mQl?||gFKj-!cBx4&9Aa$I`kV0xiMX(Lb^V6ao!FMP=}XSK4Var?ty
zK#ij6mzbWsdOQutMY}|S=xEEy6*^-9if=;J{E5=U^7I6rp?u51W6&(wd13p>{VW?E
zH7fiQ-NEz(wnvXz+BF{D7q)|<)543_+^>^8MlkPFn@^TL_RfBG>ten5u|?)KW>KBL
zU9+L};8i}U9{d`8PCHPl)xG*u3qfDTQuLtcf+_jSmQkfwL#2F?f4>Uy)8P;VJK>$Opg!r7BvSqaL4!I`;HKA
zPn-K-{*2K2!@d93cLQ1u3sq&f4MXmrVpsY@%w$&VQ2tG*^hgJbtLm-BtHu4tW$REV{)$tUaMoZr(eq
zZ_ib0Bqpq0U}tp5Sbx+Fh)TcY?sHnp);ElBMxsEkKglA0UdSr_{bj`Y-WG9gJ$>k1
zbnn(;^CO?|N@afGDsRt;NH0lL7k)W7r~GZ$3mUSx{afBaI_ze9AOw+>~=r(R6W2~T-H~DTT3@T92p)7
zDAh%;)4Xwn#)g1DrHR^DUA0@BI#c`
zmPGsm7e#)nu2X${%0aj1pGbAT;7nro=QFBxgrhfz3oozp_Lj3-QBrQ8Kf^wS<6Xmv
zKS6%J{Sxco{8YtwlktlVj@{RAbX@hkBwk!k8CWBPNgLe(S<$)lV~fLS8b>)=z^Z
zUf_0m9;qtXecvapt{qUCSUp_g@GzcSmt9Q**0+jQCw5QCG4D*K&;;}r2M4#eTci;U
zgG=g8EHt&7Y6&f4Tx+L0F0$<-D@=7YCB!W^OI_OIKEp0)P3U-Prq2-V(~W@dEFxvq
zw(pB%`#Ng($j6Y5eyil$Wdf}K@n0E
zmAUj@IcG`&@!wbx91N)9t2T##2isnPq!_)jS6;kkk}iUg)M4vqGStGgY9F{@(jHL+_(g2V|RZa;H9}L~ggsE7;ZcKjC+Vycw8~do#p4RJDeD
z64BLc7%lrimCv|d8Y>b=*tmWwyE>M+>ygeZJ$T5Qx#$}LT{9=NP(q<80un63KOzJG
z1H__$Wf#LH=TJ3)!{K!HkI!K8u@~iqswL}zwR!h(bubte;(Qzgok5Btp
zrgfEVIo;7M)zRGp&s$5=-r}iKacSAXo$Flm`1p^$$}w(%R+z6a@0GqCtu?)VKj%a9g4*f8ZWDMP9;GoC0CbxWICCwwZg
z;P8TEj$&H2VrU-n+GhVIEtX?8VrUFPl3qLH)>PfAGxvVvsZuJe!#?T#v;&*|UqHb~d>T!PW&;Oy)FL<2`%@fZ%jF+bCeEfHybKUR
zLY{ib=(1sMuUyFmyJnPrl!lvJg+e6rE~5(_chf$%=AIUXCP7JihkKh;(01JU!|PvT
zJ5BX%O0n!4-mz$7m^I!On%g?lW%^N)7u16(ilkRSBr)W+!H6FIT-EsWPviMhlktBv
z)b#hSSXg2UeCuV
z@{t%J@prGxr*fL($m_I#4O`Ru&Rp}yDR1)lSL&!^%U8?3M?Pq-lJYvD@nnOjtk93m
zs=A+I@Q<82Ozf5Bi8<{pLzlP-Op(pHJ(OpTNIaV;b!%2f=7u&t#kKJMq8f;#_Qbp)
zMY<5g-1a0JJR>^+f%!Nset)j^G@>QzS^TuElVYu$y?`N@PC$iy$*bMe-8T57%GJwDBnmAFAr_U*=ES04k>V`xF_}8(!_kmO*)=t4gTzG5vKu1UPuue;t-!fE4Ul
zVyY(2{yA0TZ1HEtT~1TyqALG{Xd)EP%r2j(!7l(vF6h9qM_2U3HBh51(-)nPxig}~
zCqZ&{Zi*CtLtzt$lv~Oo7co3)6jYPpsp2|;Dz~2ivzOMU_Lll#`Ym-dEo~$Iv%ZA{
zIO1^7dQ)DD5mFy8f}Uf3q^83tqQDh0JL`ZE*Z*1^%mB
zFeu$v7Zz7U-FL`x1kxtxxmE#JMAeg_EoX+<;#S{%SJK9KANZAe+&GYC
z@G$0Ew()7G3_d?f#74Ib9%Jye0mgAicD&TCv47
zb&i4%2SmNMz2(g5Cxh`Cte9#di*BjSXkZ0}&
z;}P?ti4JB*_UA|#20||mq$U~I67#bf)>mTY)5v)!4j(I|A@=-Sk|V>wluSPKx5>9y
z;R%^O!29pwV~pGwuaMLShWqD0q>tt`aMdt7)B{WU$jmttwACt5;
z^|8Jo<(r?;>YCi(P+@Bz#aL^D8jW@L1bRB(x$hnJxCVJW``qVRZj|sJ#cPrTm+N_v673iFll3o(;Al+cM`ks*9+IJDjjErXGBpJI(Qi(N^7D=
zNwa1-rQK;;2+?^(uLfw5?Cx{h)30a)QPi72ig6`a!nimD4G9EAnr1&-u829%&rr;a
zs4wKrRamG3f=N~*5Na_~ElfU3bW$v5e?u4t?%drpq3YTUq=I*FK3V&=Yd{*be!pBU0ep_`6qk)5ipyb~io>=NqcC8H0wUuOR<#P1&XSlNaR+XqJZ^gTU7W_=
z?{Q%WKug)4luPW_6RO5(Bc;x0rp-C~$FDWIHF6k9IWY2h);59Hw6*x$jJv(z%(cHk
z$<^35nmFnS1oJ3B*qpL@20P6SjBp!SE#2(n*qece4YzP_?WyRaiTJbv_;-k(_fm
z3-%dANOBr8i^cF;fxv|jIUeflbyPW3be|~}e`U2iN^bNTM3=gN)SN4Ze?0Vynm>>f
zkwkUA!;>qRDKw;(_gy6mf%p2Q-fj&@A@%`sXjcz`C3e{9pQj=>H&a
z5nKZIKWMng%2yWt`pCAW%(cRJcR9{=IVRXM$~E)%l`CG-7G?TsV<+ba>z~OQi{nB+
z;}M=)jply-F-dfECd|rOD&)y5JZ(F*)huA1o0*8(#})1mmt)Ah|7&xeRWD(LoLQnF
zmBoANmIT?N7-z9RT-x%6Rzkka|6n$u9GjRd9?5^;{+aqeb!f3*grYMTSa<{k=+)Xk
z;Mt#IA*Nc@1+5lSXXGmJIU87_#^--)#UCV(_4`0U2S?6oaUoJJAji=K2P8c5z6GmW
zA@b6j-Is3l)tkw7Lmx-`wKv_u)@Hh^093U4FTuI
zd76B2)bn!`R!AO$Ww(8^M6NM)f~I6|PyDj&NR}@qAj#6t{FnU60N!y&H(7ea{h^eFY+j?$$((xN-0L
z$40(V8>Pgk0mW
z&J4;a+c!mV^(2+}5Q72!@elaLzuq|x4w$u(^@UImO(^uo-E(%PCDuhIlurG|4He)Z
z(x@~J9X7@MRYf!@gLgN%*};!vl02LZBb$->91Z=d)WAPisL(?iEX+SoXn?;EK*fX{
zdN4C~&N=)0J^w&%tlNz;ufm=f>6~m-Y&cz!JjnZPyq?Z)9ixt%PogzwN;?$mqog#CjIcJLbSRN033ikgmt8_o(&_imC3tT`diQbWto^
zC$8DWB%JH2gHdj*o=4dP%_L>7u~*x+n%3>3vT5erTc}3)ec7~;$=sag3IpzG>G5M%
zJB21*Tii9Qhogs`kUC7cr>Ua|3L@Elh^5YwI*vPtvu-OAGVJAxBcX{IG4My)t1_$*
zbcP$<<-@$`*d0FMG)0#mX?=Vp(uMwXGi~%1ElkV=>}%!=-4!qvbWfnpjtJe9|8jPi
zKfCf57h%QcRCS$Ak+T_ZK!K6dzmO5uvyg93N@Ui>&M=%5Yq7Umu~MlS;)qCyfNfB|
z(T69RH~J%FC;IW09l<3uZaceu`Fc%SX42^-OI}jl3H?CgUqcZi&$!1i>ua35s5mDY
zqLeyGiJ_2KYzr1lT7KKuvs4+roEDi~jU2t2^{$Mzw0Hsa(UCfA4ZVJwqRS7J8{17ajc<(;9NZG8hnOv_CEJrEvd8EV|`~!mO
z$$iYc=jq_Z%2sll_x7<~S?}ai!DrRFTP_Kk{`nE>k-l!M6P2$Rmsb;h%e-&oGgxz=
zQalFH-2Mz65m6kLKhLY_-$~dRp?Lvx7FBSBw^LFF%zgc$&h1&nMS`;O@%u-8F*8krKC1?d`WZ~W^oBn
zc*X&PfoN|Ym07No6NZ>eE!UYwwGq3t6UUttmyq*Uv>ur*EmCE7%t$d)Of`2s%gTKA
z?q$dInJJ9Hj2~Hv?`qgUrFn%c2sC%yM>eSxsneLbZIU7Iz$rpEL6#IE;xmFov0`gy^RAOP&!3CL2X7u#KYQj~^vh%5LGPy}O8txKSF+4lx&E}2pp6`*w
zY>p(ArE)r#3}Xfa`VIDvT0bHatl=D-U!S-M@XViWd^*BI$j;5c
zVIMKz8#T)GbsYxd+T47mttfwY^84n?IV30n#F&Ye1YXMS-`&``=3ki0^0qQX3S}&YSqi!G*3mm
z3HeU&p3JJD`30cDz|FQZu!%lK|2=>98_&Ryu`U=LxlIAmMj79&TsjnN)|dt3H&?p&
zyuosY>Dx=_FXQBcaw;?SqaS_f{ri)y$mO
zq-$TVeb?2Gp1^&>?T2@hxpvv^Xf1Fz!x~-o7KRcJ!RjUti86(|M(8TTk{yYc4ETb`
zbYf3%LxTbD4v~t55CJK{>utO3w72)t2pD>KAp^|Rx=C1q986;Y$8p8}MU})B+daS)
z%(ibuQU=ziVScu|&8}+DmU}Bom%AqgQ=5c;aM2c8uY~f`pU9QD
z5K-oBBl_cE*tK9yx;*yyYQ<@AWyPP_?1#7`{=l(jE6M2r%Fc@Z{BpH`I?n6WLyj0d
zL+nUd|DDmQE}5i7KF8j|B;StB1(Ze7yZxi_qMIvd6UW6TdIN+13BGs;dA;Zyw
z&LH7sGaP(MsOGZ~^3$kSJ}dteAzN|7!G{i+b$<>~#M!zjGvXQp#Zyg&@U1L@_YoN<
z?zc$$LXRbVr!$@pj~;%L>(=l?*1D<{I8WQKY_Oe1AbdA_Zh{Y=6w_^;0@g;(d0?$q
z7LW7fxs!WWPFb)pyAVfGf+n%k&V^TLJw;^FQvm5qprF!=;7kPU5*!o-Og`IYjX7{V
z#8!ud%@$cTBD%dIwAN+)ehZ!H^=l++c+c+A6@?`Gy4m4R*u!QqgRSS&@RPyd`9PG{
zNDNZFXwvd2SWJV!u`_LNp#0hC-;88^{wo$kx$W%Vxa8_qJcK`OcKZTwkBT>e`^?h6O4+}q
ze@QE+%SpeD(OcA6?~~*d*?i@YaQvyvH!JpX>e(CZ0`ZDG2-Uu~Q5b?JHl5!xB7(hy
z2M-qDrAmO|i+{(Ecjyaf#3BRweS5c|&Az5>d?o6?ZBpxE*6}Jx%J`}}ci(npL-++S
z@&En^SPY%2*N+~0MLBKT_oPLH!1dP0K}5yhwJF}2_
Date: Thu, 16 Apr 2026 19:09:40 +0800
Subject: [PATCH 007/261] fix: fix outbox watermark context expiry and add
in-batch group rebuild dedup
Fixes #1691
- pollOutbox() reused a 10s context for SetOutboxWatermark, but event
  processing before the write could take much longer, causing "outbox watermark write
failed: context deadline exceeded". The watermark never advanced so
the same 200 events were reprocessed every poll cycle, spiking CPU.
Now uses an independent 5s context with up to 3 retries (200ms apart).
- When multiple Codex accounts sharing the same 21-22 groups are all
rate-limited in quick succession, each account_changed event triggered
redundant bucket rebuild attempts for the same groups. Introduce
batchSeenKey{groupID, platform} and thread a seen map through the
handler chain; rebuildBucketsForPlatform skips (group, platform) pairs
already rebuilt within the same poll batch (~80% fewer rebuild calls in
the 5-accounts-same-groups scenario).
Co-Authored-By: Claude Sonnet 4.6 (1M context)
---
.../service/scheduler_snapshot_service.go | 82 +++++++++++++------
1 file changed, 57 insertions(+), 25 deletions(-)
diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go
index d1330abb..5ead45bc 100644
--- a/backend/internal/service/scheduler_snapshot_service.go
+++ b/backend/internal/service/scheduler_snapshot_service.go
@@ -20,6 +20,14 @@ var (
const outboxEventTimeout = 2 * time.Minute
+// batchSeenKey tracks which (groupID, platform) bucket sets have already been
+// rebuilt within a single pollOutbox call, to avoid redundant work when multiple
+// account_changed events share the same groups.
+type batchSeenKey struct {
+ groupID int64
+ platform string
+}
+
type SchedulerSnapshotService struct {
cache SchedulerCache
outboxRepo SchedulerOutboxRepository
@@ -244,9 +252,10 @@ func (s *SchedulerSnapshotService) pollOutbox() {
}
watermarkForCheck := watermark
+ seen := make(map[batchSeenKey]struct{})
for _, event := range events {
eventCtx, cancel := context.WithTimeout(context.Background(), outboxEventTimeout)
- err := s.handleOutboxEvent(eventCtx, event)
+ err := s.handleOutboxEvent(eventCtx, event, seen)
cancel()
if err != nil {
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
@@ -255,8 +264,20 @@ func (s *SchedulerSnapshotService) pollOutbox() {
}
lastID := events[len(events)-1].ID
- if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil {
- logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", err)
+ wmCtx, wmCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer wmCancel()
+ var wmErr error
+ for i := range 3 {
+ wmErr = s.cache.SetOutboxWatermark(wmCtx, lastID)
+ if wmErr == nil {
+ break
+ }
+ if i < 2 {
+ time.Sleep(200 * time.Millisecond)
+ }
+ }
+ if wmErr != nil {
+ logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", wmErr)
} else {
watermarkForCheck = lastID
}
@@ -264,18 +285,18 @@ func (s *SchedulerSnapshotService) pollOutbox() {
s.checkOutboxLag(ctx, events[0], watermarkForCheck)
}
-func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent) error {
+func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent, seen map[batchSeenKey]struct{}) error {
switch event.EventType {
case SchedulerOutboxEventAccountLastUsed:
return s.handleLastUsedEvent(ctx, event.Payload)
case SchedulerOutboxEventAccountBulkChanged:
- return s.handleBulkAccountEvent(ctx, event.Payload)
+ return s.handleBulkAccountEvent(ctx, event.Payload, seen)
case SchedulerOutboxEventAccountGroupsChanged:
- return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
+ return s.handleAccountEvent(ctx, event.AccountID, event.Payload, seen)
case SchedulerOutboxEventAccountChanged:
- return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
+ return s.handleAccountEvent(ctx, event.AccountID, event.Payload, seen)
case SchedulerOutboxEventGroupChanged:
- return s.handleGroupEvent(ctx, event.GroupID)
+ return s.handleGroupEvent(ctx, event.GroupID, seen)
case SchedulerOutboxEventFullRebuild:
return s.triggerFullRebuild("outbox")
default:
@@ -309,7 +330,7 @@ func (s *SchedulerSnapshotService) handleLastUsedEvent(ctx context.Context, payl
return s.cache.UpdateLastUsed(ctx, updates)
}
-func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any) error {
+func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any, seen map[batchSeenKey]struct{}) error {
if payload == nil {
return nil
}
@@ -323,15 +344,15 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
}
ids := make([]int64, 0, len(rawIDs))
- seen := make(map[int64]struct{}, len(rawIDs))
+ seenIDs := make(map[int64]struct{}, len(rawIDs))
for _, id := range rawIDs {
if id <= 0 {
continue
}
- if _, exists := seen[id]; exists {
+ if _, exists := seenIDs[id]; exists {
continue
}
- seen[id] = struct{}{}
+ seenIDs[id] = struct{}{}
ids = append(ids, id)
}
if len(ids) == 0 {
@@ -384,10 +405,10 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
for gid := range rebuildGroupSet {
rebuildGroupIDs = append(rebuildGroupIDs, gid)
}
- return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change")
+ return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change", seen)
}
-func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
+func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any, seen map[batchSeenKey]struct{}) error {
if accountID == nil || *accountID <= 0 {
return nil
}
@@ -408,7 +429,7 @@ func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accou
return err
}
}
- return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss")
+ return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss", seen)
}
return err
}
@@ -420,18 +441,18 @@ func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accou
if len(groupIDs) == 0 {
groupIDs = account.GroupIDs
}
- return s.rebuildByAccount(ctx, account, groupIDs, "account_change")
+ return s.rebuildByAccount(ctx, account, groupIDs, "account_change", seen)
}
-func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64) error {
+func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64, seen map[batchSeenKey]struct{}) error {
if groupID == nil || *groupID <= 0 {
return nil
}
groupIDs := []int64{*groupID}
- return s.rebuildByGroupIDs(ctx, groupIDs, "group_change")
+ return s.rebuildByGroupIDs(ctx, groupIDs, "group_change", seen)
}
-func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string) error {
+func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string, seen map[batchSeenKey]struct{}) error {
if account == nil {
return nil
}
@@ -441,21 +462,21 @@ func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account
}
var firstErr error
- if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason); err != nil && firstErr == nil {
+ if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason, seen); err != nil && firstErr == nil {
firstErr = err
}
if account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
- if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason); err != nil && firstErr == nil {
+ if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason, seen); err != nil && firstErr == nil {
firstErr = err
}
- if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason); err != nil && firstErr == nil {
+ if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason, seen); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
-func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string) error {
+func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string, seen map[batchSeenKey]struct{}) error {
groupIDs = s.normalizeGroupIDs(groupIDs)
if len(groupIDs) == 0 {
return nil
@@ -463,19 +484,30 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
var firstErr error
for _, platform := range platforms {
- if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
+ if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
-func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string) error {
+func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string, seen map[batchSeenKey]struct{}) error {
if platform == "" {
return nil
}
var firstErr error
for _, gid := range groupIDs {
+ // Within a single poll batch, skip (groupID, platform) pairs that were
+ // already rebuilt. The first rebuild loads fresh DB data for all accounts
+ // in the group, so subsequent rebuilds for the same group+platform within
+ // the same batch are redundant.
+ if seen != nil {
+ key := batchSeenKey{gid, platform}
+ if _, exists := seen[key]; exists {
+ continue
+ }
+ seen[key] = struct{}{}
+ }
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeSingle}, reason); err != nil && firstErr == nil {
firstErr = err
}
--
GitLab
From 697c41a3f6d67671942ef941273095b02043dfc7 Mon Sep 17 00:00:00 2001
From: Elysia <1628615876@qq.com>
Date: Thu, 16 Apr 2026 20:41:40 +0800
Subject: [PATCH 008/261] fix: create fresh context per watermark write retry
attempt
Each retry in the SetOutboxWatermark loop now gets its own 5s context.
Previously a shared context could already be expired when the second or
third attempt ran, making the retries pointless.
Co-Authored-By: Claude Sonnet 4.6 (1M context)
---
backend/internal/service/scheduler_snapshot_service.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go
index 5ead45bc..62b6993d 100644
--- a/backend/internal/service/scheduler_snapshot_service.go
+++ b/backend/internal/service/scheduler_snapshot_service.go
@@ -264,11 +264,11 @@ func (s *SchedulerSnapshotService) pollOutbox() {
}
lastID := events[len(events)-1].ID
- wmCtx, wmCancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer wmCancel()
var wmErr error
for i := range 3 {
+ wmCtx, wmCancel := context.WithTimeout(context.Background(), 5*time.Second)
wmErr = s.cache.SetOutboxWatermark(wmCtx, lastID)
+ wmCancel()
if wmErr == nil {
break
}
--
GitLab
From a789c8c4c70c72aa02a81736fc8ab8d9b8571f14 Mon Sep 17 00:00:00 2001
From: shaw
Date: Fri, 17 Apr 2026 09:37:25 +0800
Subject: [PATCH 009/261] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81opus-4.7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/domain/constants.go | 2 +
.../internal/pkg/antigravity/claude_types.go | 1 +
.../pkg/antigravity/request_transformer.go | 12 +-
backend/internal/pkg/claude/constants.go | 6 +
backend/internal/service/billing_service.go | 6 +
backend/internal/service/pricing_service.go | 114 +++++++++++-------
frontend/src/composables/useModelWhitelist.ts | 7 +-
7 files changed, 101 insertions(+), 47 deletions(-)
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index 429486c3..a57f7067 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -71,6 +71,7 @@ const (
// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致
var DefaultAntigravityModelMapping = map[string]string{
// Claude 白名单
+ "claude-opus-4-7": "claude-opus-4-7", // 官方模型
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
@@ -120,6 +121,7 @@ var DefaultAntigravityModelMapping = map[string]string{
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
+ "claude-opus-4-7": "us.anthropic.claude-opus-4-7-v1",
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go
index ce144bb9..0b8ae5f2 100644
--- a/backend/internal/pkg/antigravity/claude_types.go
+++ b/backend/internal/pkg/antigravity/claude_types.go
@@ -154,6 +154,7 @@ var claudeModels = []modelDef{
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
+ {ID: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", CreatedAt: "2026-04-17T00:00:00Z"},
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}
diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go
index d13a8498..b5de8166 100644
--- a/backend/internal/pkg/antigravity/request_transformer.go
+++ b/backend/internal/pkg/antigravity/request_transformer.go
@@ -582,8 +582,12 @@ func maxOutputTokensLimit(model string) int {
return maxOutputTokensUpperBound
}
-func isAntigravityOpus46Model(model string) bool {
- return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6")
+// isAntigravityOpusHighTierModel 判断是否为高阶 Opus 模型(4.6+),
+// 用于 adaptive thinking 时覆写为高预算。
+func isAntigravityOpusHighTierModel(model string) bool {
+ lower := strings.ToLower(model)
+ return strings.HasPrefix(lower, "claude-opus-4-6") ||
+ strings.HasPrefix(lower, "claude-opus-4-7")
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
@@ -605,12 +609,12 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
// - thinking.type=enabled:budget_tokens>0 用显式预算
- // - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576)
+ // - thinking.type=adaptive:在 Antigravity 的高阶 Opus(4.6+)上覆写为 (24576)
budget := -1
if req.Thinking.BudgetTokens > 0 {
budget = req.Thinking.BudgetTokens
}
- if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) {
+ if req.Thinking.Type == "adaptive" && isAntigravityOpusHighTierModel(req.Model) {
budget = ClaudeAdaptiveHighThinkingBudgetTokens
}
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index dfca252f..21c723d2 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -83,6 +83,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z",
},
+ {
+ ID: "claude-opus-4-7",
+ Type: "model",
+ DisplayName: "Claude Opus 4.7",
+ CreatedAt: "2026-04-17T00:00:00Z",
+ },
{
ID: "claude-sonnet-4-6",
Type: "model",
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 763abadb..32a54cbe 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -191,6 +191,9 @@ func (s *BillingService) initFallbackPricing() {
// Claude 4.6 Opus (与4.5同价)
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
+ // Claude 4.7 Opus (暂与4.6同价,待官方定价更新)
+ s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"]
+
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
@@ -278,6 +281,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
// 按模型系列匹配
if strings.Contains(modelLower, "opus") {
+ if strings.Contains(modelLower, "4.7") || strings.Contains(modelLower, "4-7") {
+ return s.fallbackPrices["claude-opus-4.7"]
+ }
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
return s.fallbackPrices["claude-opus-4.6"]
}
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index 3b3f31c3..2bf48702 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -656,65 +656,95 @@ func (s *PricingService) extractBaseName(model string) string {
// matchByModelFamily 基于模型系列匹配
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
- // Claude模型系列匹配规则
- familyPatterns := map[string][]string{
- "opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"},
- "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
- "opus-4": {"claude-opus-4", "claude-3-opus"},
- "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
- "sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
- "sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
- "sonnet-3": {"claude-3-sonnet"},
- "haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
- "haiku-3": {"claude-3-haiku"},
- }
-
- // 确定模型属于哪个系列
- var matchedFamily string
- for family, patterns := range familyPatterns {
- for _, pattern := range patterns {
+ // modelFamily 定义一个模型系列的匹配和定价查找规则。
+ type modelFamily struct {
+ name string // 系列名称
+ match []string // 用于将模型归类到此系列的模式(strings.Contains 匹配)
+ pricing []string // 用于在定价数据中查找价格的模式(nil 则复用 match;可包含低版本 fallback)
+ }
+
+ // 按特异性降序排列:高版本号在前,避免 "claude-opus-4"(opus-4 系列)
+ // 因子串关系误匹配 "claude-opus-4-7"(opus-4.7 系列)。
+ // 注意:原 map 实现存在 Go map 迭代随机性导致的同类 bug,此处改为有序切片修复。
+ families := []modelFamily{
+ {name: "opus-4.7", match: []string{"claude-opus-4-7", "claude-opus-4.7"}, pricing: []string{"claude-opus-4-7", "claude-opus-4.7", "claude-opus-4-6"}},
+ {name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}},
+ {name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}},
+ {name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}},
+ {name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}},
+ {name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}},
+ {name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}},
+ {name: "sonnet-3", match: []string{"claude-3-sonnet"}},
+ {name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}},
+ {name: "haiku-3", match: []string{"claude-3-haiku"}},
+ }
+
+ // Phase 1: 按有序切片归类(最具体的系列优先匹配)
+ var matched *modelFamily
+ for i := range families {
+ for _, pattern := range families[i].match {
if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) {
- matchedFamily = family
+ matched = &families[i]
break
}
}
- if matchedFamily != "" {
+ if matched != nil {
break
}
}
- if matchedFamily == "" {
- // 简单的系列匹配
- if strings.Contains(model, "opus") {
- if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
- matchedFamily = "opus-4.5"
- } else {
- matchedFamily = "opus-4"
+ // Phase 2: 二次兜底——当模型 ID 不含已知模式串时,按关键字粗分
+ if matched == nil {
+ var fallbackName string
+ switch {
+ case strings.Contains(model, "opus"):
+ switch {
+ case strings.Contains(model, "4.7") || strings.Contains(model, "4-7"):
+ fallbackName = "opus-4.7"
+ case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"):
+ fallbackName = "opus-4.6"
+ case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
+ fallbackName = "opus-4.5"
+ default:
+ fallbackName = "opus-4"
}
- } else if strings.Contains(model, "sonnet") {
- if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
- matchedFamily = "sonnet-4.5"
- } else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
- matchedFamily = "sonnet-3.5"
- } else {
- matchedFamily = "sonnet-4"
+ case strings.Contains(model, "sonnet"):
+ switch {
+ case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"):
+ fallbackName = "sonnet-4.5"
+ case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
+ fallbackName = "sonnet-3.5"
+ default:
+ fallbackName = "sonnet-4"
}
- } else if strings.Contains(model, "haiku") {
- if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
- matchedFamily = "haiku-3.5"
- } else {
- matchedFamily = "haiku-3"
+ case strings.Contains(model, "haiku"):
+ switch {
+ case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"):
+ fallbackName = "haiku-3.5"
+ default:
+ fallbackName = "haiku-3"
+ }
+ }
+ if fallbackName != "" {
+ for i := range families {
+ if families[i].name == fallbackName {
+ matched = &families[i]
+ break
+ }
}
}
}
- if matchedFamily == "" {
+ if matched == nil {
return nil
}
- // 在价格数据中查找该系列的模型
- patterns := familyPatterns[matchedFamily]
- for _, pattern := range patterns {
+ // Phase 3: 在定价数据中查找该系列的价格
+ lookups := matched.pricing
+ if lookups == nil {
+ lookups = matched.match
+ }
+ for _, pattern := range lookups {
for key, pricing := range s.pricingData {
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, pattern) {
diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts
index d16ed977..a282ae7d 100644
--- a/frontend/src/composables/useModelWhitelist.ts
+++ b/frontend/src/composables/useModelWhitelist.ts
@@ -43,6 +43,7 @@ export const claudeModels = [
'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001',
'claude-opus-4-5-20251101',
'claude-opus-4-6',
+ 'claude-opus-4-7',
'claude-sonnet-4-6',
'claude-2.1', 'claude-2.0', 'claude-instant-1.2'
]
@@ -66,6 +67,7 @@ const antigravityModels = [
// Claude 4.5+ 系列
'claude-opus-4-6',
'claude-opus-4-6-thinking',
+ 'claude-opus-4-7',
'claude-opus-4-5-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-5',
@@ -250,6 +252,7 @@ const anthropicPresetMappings = [
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
{ label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
+ { label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
{ label: 'Haiku 4.5', from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4-5-20251001', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
{ label: 'Opus->Sonnet', from: 'claude-opus-4-6', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
@@ -309,12 +312,14 @@ const antigravityPresetMappings = [
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
- { label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
+ { label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
+ { label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
]
// Bedrock 预设映射(与后端 DefaultBedrockModelMapping 保持一致)
const bedrockPresetMappings = [
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
+ { label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'us.anthropic.claude-opus-4-7-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
--
GitLab
From 5d586a9f3a18d724d7883fa636328400c302e261 Mon Sep 17 00:00:00 2001
From: shaw
Date: Fri, 17 Apr 2026 10:17:50 +0800
Subject: [PATCH 010/261] =?UTF-8?q?fix:=20=E4=B8=8A=E6=B8=B8=E8=BF=94?=
=?UTF-8?q?=E5=9B=9E=20KYC=20=E8=BA=AB=E4=BB=BD=E9=AA=8C=E8=AF=81=E8=A6=81?=
=?UTF-8?q?=E6=B1=82=E6=97=B6=E5=81=9C=E6=AD=A2=E8=B4=A6=E5=8F=B7=E8=B0=83?=
=?UTF-8?q?=E5=BA=A6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/service/ratelimit_service.go | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 4d8009b7..53581574 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -152,6 +152,11 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
msg := "Credit balance exhausted (400): " + upstreamMsg
s.handleAuthError(ctx, account, msg)
shouldDisable = true
+ } else if strings.Contains(strings.ToLower(upstreamMsg), "identity verification is required") {
+ // KYC 身份验证要求 → 永久禁用,账号需完成身份验证后才能恢复
+ msg := "Identity verification required (400): " + upstreamMsg
+ s.handleAuthError(ctx, account, msg)
+ shouldDisable = true
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case 401:
--
GitLab
From 6cfdf4ec051fcac04456367b648622bce175ad7d Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 17 Apr 2026 02:51:18 +0000
Subject: [PATCH 011/261] chore: sync VERSION to 0.1.114 [skip ci]
---
backend/cmd/server/VERSION | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index c21e67e6..c29f5f75 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.113
+0.1.114
--
GitLab
From fd0c9a130530eeab28c212cec601a3ef5ba1f843 Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 17:00:29 +0800
Subject: [PATCH 012/261] fix(payment): store provider config as plaintext JSON
with legacy ciphertext fallback
Without TOTP_ENCRYPTION_KEY, saved payment configs were lost on restart because
the AES round-trip failed silently. Write new records as plaintext JSON; the read
path tries JSON first, falls back to legacy AES decrypt when a key is present,
and treats unreadable values as empty so admins can re-enter them via the UI.
---
backend/internal/payment/crypto.go | 11 ++-
backend/internal/payment/load_balancer.go | 30 ++++--
.../internal/payment/load_balancer_test.go | 97 +++++++++++++++++++
.../service/payment_config_providers.go | 38 +++++---
4 files changed, 151 insertions(+), 25 deletions(-)
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
index e39e957f..5467e50b 100644
--- a/backend/internal/payment/crypto.go
+++ b/backend/internal/payment/crypto.go
@@ -10,12 +10,15 @@ import (
"strings"
)
+// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
+const AES256KeySize = 32
+
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
func Encrypt(plaintext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
block, err := aes.NewCipher(key)
@@ -52,8 +55,8 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
func Decrypt(ciphertext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
parts := strings.SplitN(ciphertext, ":", 3)
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index f0353173..52a1b011 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
if err != nil {
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
}
+ if config == nil {
+ config = map[string]string{}
+ }
if selected.PaymentMode != "" {
config["paymentMode"] = selected.PaymentMode
@@ -275,16 +278,29 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
}, nil
}
-func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
- plaintext, err := Decrypt(encrypted, lb.encryptionKey)
- if err != nil {
- return nil, err
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
+// Unreadable values (legacy ciphertext without a valid key, or malformed data)
+// are treated as empty so the service keeps running while the admin re-enters
+// the config via the UI.
+func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
+ return nil, nil
}
var config map[string]string
- if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
- return nil, fmt.Errorf("unmarshal config: %w", err)
+ if err := json.Unmarshal([]byte(stored), &config); err == nil {
+ return config, nil
+ }
+ if len(lb.encryptionKey) == AES256KeySize {
+ if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
+ return config, nil
+ }
+ }
}
- return config, nil
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go
index 04b3c25b..2bf4f6ac 100644
--- a/backend/internal/payment/load_balancer_test.go
+++ b/backend/internal/payment/load_balancer_test.go
@@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) {
}
}
+func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
+ t.Parallel()
+
+ key := make([]byte, AES256KeySize)
+ for i := range key {
+ key[i] = byte(i + 1)
+ }
+ wrongKey := make([]byte, AES256KeySize)
+ for i := range wrongKey {
+ wrongKey[i] = byte(0xFF - i)
+ }
+
+ plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
+
+ legacyEncrypted, err := Encrypt(plaintextJSON, key)
+ if err != nil {
+ t.Fatalf("seed Encrypt: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ stored string
+ key []byte
+ want map[string]string
+ }{
+ {
+ name: "empty stored returns nil map",
+ stored: "",
+ key: key,
+ want: nil,
+ },
+ {
+ name: "plaintext JSON parses directly",
+ stored: plaintextJSON,
+ key: nil,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "plaintext JSON works even with key present",
+ stored: plaintextJSON,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with correct key decrypts",
+ stored: legacyEncrypted,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with no key treated as empty",
+ stored: legacyEncrypted,
+ key: nil,
+ want: nil,
+ },
+ {
+ name: "legacy ciphertext with wrong key treated as empty",
+ stored: legacyEncrypted,
+ key: wrongKey,
+ want: nil,
+ },
+ {
+ name: "garbage data treated as empty",
+ stored: "not-json-and-not-ciphertext",
+ key: key,
+ want: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ lb := NewDefaultLoadBalancer(nil, tt.key)
+ got, err := lb.decryptConfig(tt.stored)
+ if err != nil {
+ t.Fatalf("decryptConfig unexpected error: %v", err)
+ }
+ if !stringMapEqual(got, tt.want) {
+ t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+// stringMapEqual compares two map[string]string values; nil and empty are equal.
+func stringMapEqual(a, b map[string]string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, v := range a {
+ if bv, ok := b[k]; !ok || bv != v {
+ return false
+ }
+ }
+ return true
+}
+
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 3c406b45..59337ad6 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"strconv"
"strings"
@@ -290,19 +291,29 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return existing, nil
}
-func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
- if encrypted == "" {
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
+// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
+// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
+// letting the admin re-enter the config via the UI to complete the migration.
+func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
return nil, nil
}
- decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
- if err != nil {
- return nil, fmt.Errorf("decrypt config: %w", err)
+ var cfg map[string]string
+ if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
+ return cfg, nil
}
- var raw map[string]string
- if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
- return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
+ if len(s.encryptionKey) == payment.AES256KeySize {
+ if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
+ return cfg, nil
+ }
+ }
}
- return raw, nil
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
@@ -317,14 +328,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
+// encryptConfig serialises a provider config for storage.
+// New records are written as plaintext JSON; the historical AES-GCM wrapping
+// has been dropped but decryptConfig still accepts old ciphertext during migration.
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshal config: %w", err)
}
- enc, err := payment.Encrypt(string(data), s.encryptionKey)
- if err != nil {
- return "", fmt.Errorf("encrypt config: %w", err)
- }
- return enc, nil
+ return string(data), nil
}
--
GitLab
From 44cdef7934168f167c5d433e5947aea3ac5a279d Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 17:00:45 +0800
Subject: [PATCH 013/261] fix(usage): subscription billing honours group rate
multiplier
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Subscription-mode billing was consuming quota at TotalCost (raw) instead of
ActualCost (TotalCost * RateMultiplier), so per-group rate multipliers —
including free subscriptions (multiplier = 0) — were silently ignored.
Switch the three subscription cost writes in buildUsageBillingCommand,
finalizePostUsageBilling, and the legacy postUsageBilling fallback to
ActualCost, and add a table-driven test covering 2x / 0.5x / free multipliers
plus a balance-mode regression check.
---
backend/internal/service/gateway_service.go | 16 ++--
...teway_service_subscription_billing_test.go | 85 +++++++++++++++++++
2 files changed, 96 insertions(+), 5 deletions(-)
create mode 100644 backend/internal/service/gateway_service_subscription_billing_test.go
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 4b4fc0bf..07a9e41c 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
cost := p.Cost
if p.IsSubscriptionBill {
- if cost.TotalCost > 0 {
- if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
+ // Subscription usage tracked by ActualCost so group rate multiplier
+ // consumes the quota at the expected speed.
+ if cost.ActualCost > 0 {
+ if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
}
@@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
}
}
+ // Record subscription / balance cost using ActualCost so the group (and any
+ // user-specific) rate multiplier consumes subscription quota at the expected
+ // speed. TotalCost remains the raw (pre-multiplier) value; the downstream
+ // guards on ActualCost > 0 still skip free subscriptions (RateMultiplier == 0).
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID
- cmd.SubscriptionCost = p.Cost.TotalCost
+ cmd.SubscriptionCost = p.Cost.ActualCost
} else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost
}
@@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
}
if p.IsSubscriptionBill {
- if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
- deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
+ if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
+ deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
}
} else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go
new file mode 100644
index 00000000..42a81035
--- /dev/null
+++ b/backend/internal/service/gateway_service_subscription_billing_test.go
@@ -0,0 +1,85 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+)
+
+// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
+// that subscription-mode billing honours the group (and any user-specific) rate
+// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
+// RateMultiplier), not raw TotalCost.
+func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
+ t.Parallel()
+
+ groupID := int64(7)
+ subID := int64(42)
+
+ tests := []struct {
+ name string
+ totalCost float64
+ actualCost float64
+ isSubscription bool
+ wantSub float64
+ wantBalance float64
+ }{
+ {
+ name: "subscription with 2x multiplier consumes 2x quota",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: true,
+ wantSub: 2.0,
+ wantBalance: 0,
+ },
+ {
+ name: "subscription with 0.5x multiplier consumes 0.5x quota",
+ totalCost: 1.0,
+ actualCost: 0.5,
+ isSubscription: true,
+ wantSub: 0.5,
+ wantBalance: 0,
+ },
+ {
+ name: "free subscription (multiplier 0) consumes no quota",
+ totalCost: 1.0,
+ actualCost: 0,
+ isSubscription: true,
+ wantSub: 0,
+ wantBalance: 0,
+ },
+ {
+ name: "balance billing keeps using ActualCost (regression)",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: false,
+ wantSub: 0,
+ wantBalance: 2.0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ p := &postUsageBillingParams{
+ Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
+ User: &User{ID: 1},
+ APIKey: &APIKey{ID: 2, GroupID: &groupID},
+ Account: &Account{ID: 3},
+ Subscription: &UserSubscription{ID: subID},
+ IsSubscriptionBill: tt.isSubscription,
+ }
+
+ cmd := buildUsageBillingCommand("req-1", nil, p)
+ if cmd == nil {
+ t.Fatal("buildUsageBillingCommand returned nil")
+ }
+ if cmd.SubscriptionCost != tt.wantSub {
+ t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
+ }
+ if cmd.BalanceCost != tt.wantBalance {
+ t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
+ }
+ })
+ }
+}
--
GitLab
From 948d8e6d024412cc2efda34ec41540fddb4bfb0b Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 17:01:01 +0800
Subject: [PATCH 014/261] fix(admin): prevent browser password manager from
autofilling account API key
Chrome's password manager matched the apikey-type account's Base URL + API Key
inputs as a login form and autofilled the last saved password by domain, so
editing a Gemini account could overwrite its apikey with a Claude key that
shared the same Base URL. Add autocomplete="new-password" plus data-*-ignore
attributes for 1Password / LastPass / Bitwarden to opt the field out of every
major password manager's autofill.
---
frontend/src/components/account/EditAccountModal.vue | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 1da32e2c..59ca0b9c 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -52,6 +52,10 @@
v-model="editApiKey"
type="password"
class="input font-mono"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
:placeholder="
account.platform === 'openai'
? 'sk-proj-...'
--
GitLab
From df57d2776b1a74e37470970f1bcf8942ad810e1d Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 18:32:12 +0800
Subject: [PATCH 015/261] fix(billing): reject rate_multiplier <= 0 on save;
clamp negatives to 0 in compute
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
分组倍率和用户专属倍率在保存时没有校验,0 会触发计费层的 `<=0 → 1.0`
防御条款,结果订阅/余额分组按标准价扣费;完全是沉默地绕过了业务规则。
- 保存校验(admin_service):CreateGroup / UpdateGroup / BatchSetGroupRateMultipliers /
UpdateUser.SyncUserGroupRates 全部要求 > 0
- 计算层(billing_service):三处 `<=0 → 1.0` 改为 `<0 → 0`;负数按 0 结算,
避免配置异常被静默按 1x 收费
- 前端:分组倍率 / 用户专属倍率输入 min 统一到 0.001
- 删除未使用的 IsFreeSubscription 方法
测试:新增 billing_service_rate_multiplier_test.go 端到端验证;更新原有锁定
旧 `<=0 → 1.0` 行为的测试。
---
backend/internal/service/admin_service.go | 21 +++++++
.../service/admin_service_group_test.go | 6 ++
backend/internal/service/billing_service.go | 16 ++---
.../service/billing_service_image_test.go | 5 +-
.../billing_service_rate_multiplier_test.go | 63 +++++++++++++++++++
.../internal/service/billing_service_test.go | 28 ---------
.../service/billing_service_unified_test.go | 40 ++++--------
backend/internal/service/group.go | 4 --
.../openai_gateway_record_usage_test.go | 2 +-
.../admin/group/GroupRateMultipliersModal.vue | 2 +-
.../admin/user/UserAllowedGroupsModal.vue | 4 +-
11 files changed, 119 insertions(+), 72 deletions(-)
create mode 100644 backend/internal/service/billing_service_rate_multiplier_test.go
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7c26a47c..701f3659 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
+ // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
+ if input.GroupRates != nil {
+ for groupID, rate := range input.GroupRates {
+ if rate != nil && *rate <= 0 {
+ return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
+ }
+ }
+ }
+
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
+ if input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
+
platform := input.Platform
if platform == "" {
platform = PlatformAnthropic
@@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
+ if *input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
@@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
if s.userGroupRateRepo == nil {
return nil
}
+ for _, e := range entries {
+ if e.RateMultiplier <= 0 {
+ return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
+ }
+ }
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index a4c6d0ca..41d2c26a 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 32a54cbe..c9f32b3b 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
})
}
- if input.RateMultiplier <= 0 {
- input.RateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
+ if input.RateMultiplier < 0 {
+ input.RateMultiplier = 0
}
var breakdown *CostBreakdown
@@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown(
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
inputPrice := pricing.InputPricePerToken
@@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
// 计算总费用
totalCost := unitPrice * float64(imageCount)
- // 应用倍率
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
actualCost := totalCost * rateMultiplier
diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go
index fa90f6bb..8d3ca987 100644
--- a/backend/internal/service/billing_service_image_test.go
+++ b/backend/internal/service/billing_service_image_test.go
@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
require.Equal(t, 0.0, cost.ActualCost)
}
-// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
+// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
+// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go
new file mode 100644
index 00000000..83788196
--- /dev/null
+++ b/backend/internal/service/billing_service_rate_multiplier_test.go
@@ -0,0 +1,63 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
+// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
+func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64 // ActualCost / TotalCost
+ }{
+ {"negative clamped to 0", -1.5, 0},
+ {"zero passes through as 0 (defense in depth)", 0, 0},
+ {"positive 2x applied", 2.0, 2.0},
+ {"positive 0.5x applied", 0.5, 0.5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
+ require.NoError(t, err)
+ require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
+
+// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
+// 同样遵循"负数 → 0"语义。
+func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ price := 0.04
+ cfg := &ImagePriceConfig{Price1K: &price}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64
+ }{
+ {"negative clamped to 0", -0.5, 0},
+ {"zero passes through", 0, 0},
+ {"positive 3x applied", 3.0, 3.0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
+ require.NotNil(t, cost)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index 2cf134e2..fc8361c7 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
}
-func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
-}
-
-func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
-}
-
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService()
diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go
index 694c3384..e6a92d1a 100644
--- a/backend/internal/service/billing_service_unified_test.go
+++ b/backend/internal/service/billing_service_unified_test.go
@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
-func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
+// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
+func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
- costZero, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 0, // should default to 1.0
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
- RateMultiplier: 1.0,
+ RateMultiplier: 0,
Resolver: resolver,
})
require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
-func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
+// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
+func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
- costNeg, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 12262613..64434ae1 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
-func (g *Group) IsFreeSubscription() bool {
- return g.IsSubscriptionType() && g.RateMultiplier == 0
-}
-
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index e6fa94aa..6fa8a5bd 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
Model: "gpt-5.1",
Duration: time.Second,
},
- APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
+ APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
User: &User{ID: 200},
Account: &Account{ID: 300},
Subscription: subscription,
diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
index bf79bea2..41b2e63c 100644
--- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
+++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
@@ -166,7 +166,7 @@
Date: Fri, 17 Apr 2026 22:07:15 +0800
Subject: [PATCH 016/261] feat(gateway): raise upstream response read limit 8MB
-> 128MB (configurable)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
图片生成 API 返回的 base64 内联图响应经常超过 8MB 单次读取上限,被
ReadUpstreamResponseBody 拦截成 502 upstream_error。
单张 4K PNG base64 最坏约 67MB,多张候选图或 imageSize=4K 的 image_generation
一次请求能轻松到 30MB+。把默认上限提到 128MB 能覆盖 2-3 张 4K 图,相对
请求体上限 256MB 仍有缓冲;同时抽出 config.DefaultUpstreamResponseReadMaxBytes
共享常量,viper 默认值和 service 层兜底共用,消除两处同步魔法数字。
仍可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
---
backend/internal/config/config.go | 7 ++++++-
backend/internal/service/upstream_response_limit.go | 4 +++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index dd9a4e58..15592905 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
+// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
+// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
+// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
+const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
+
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
@@ -1407,7 +1412,7 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
+ viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go
index a0444d52..ddf0e818 100644
--- a/backend/internal/service/upstream_response_limit.go
+++ b/backend/internal/service/upstream_response_limit.go
@@ -12,7 +12,9 @@ import (
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
-const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
+// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes,
+// 仅在 cfg 为 nil 或配置值非正时作为兜底(测试或极端场景)。
+const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
--
GitLab
From 61a008f7e4e0e019f00e63409389aa5aa1058b0a Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 17 Apr 2026 23:05:58 +0800
Subject: [PATCH 017/261] chore(payment): mark legacy AES ciphertext fallback
as deprecated
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
明文 JSON 已经是新写入的默认格式;保留 AES 密文读取仅为兼容迁移期间的旧
记录,一旦所有部署通过管理后台重存过一次即可删除。标记为 deprecated 并加
TODO,几个版本后统一清理掉:payment.Encrypt / payment.Decrypt、两处
decryptConfig 的 AES 分支、PaymentConfigService.encryptionKey 和
DefaultLoadBalancer.encryptionKey 字段。
---
backend/internal/payment/crypto.go | 10 ++++++++++
backend/internal/payment/load_balancer.go | 7 +++++++
backend/internal/service/payment_config_providers.go | 6 ++++++
3 files changed, 23 insertions(+)
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
index 5467e50b..0581469d 100644
--- a/backend/internal/payment/crypto.go
+++ b/backend/internal/payment/crypto.go
@@ -16,6 +16,11 @@ const AES256KeySize = 32
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function is kept only for seeding legacy ciphertext in tests and for
+// the transitional Decrypt fallback. Scheduled for removal after all live
+// deployments complete migration by re-saving their configs.
func Encrypt(plaintext string, key []byte) (string, error) {
if len(key) != AES256KeySize {
return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
@@ -54,6 +59,11 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function remains only as a read-path fallback for pre-migration
+// ciphertext records. Scheduled for removal once all deployments re-save
+// their provider configs through the admin UI.
func Decrypt(ciphertext string, key []byte) (string, error) {
if len(key) != AES256KeySize {
return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index 52a1b011..ec244cd6 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -283,6 +283,11 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
// Unreadable values (legacy ciphertext without a valid key, or malformed data)
// are treated as empty so the service keeps running while the admin re-enters
// the config via the UI.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
+// transitional compatibility shim for pre-plaintext records. Remove it (and
+// the encryptionKey field + the Decrypt import) after a few releases once all
+// live deployments have re-saved their provider configs through the UI.
func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil
@@ -291,7 +296,9 @@ func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string,
if err := json.Unmarshal([]byte(stored), &config); err == nil {
return config, nil
}
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if len(lb.encryptionKey) == AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
return config, nil
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 59337ad6..8e470525 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -296,6 +296,10 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
// letting the admin re-enter the config via the UI to complete the migration.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
+// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
+// a few releases once all live deployments have re-saved their provider configs.
func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil
@@ -304,7 +308,9 @@ func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string,
if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
return cfg, nil
}
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if len(s.encryptionKey) == payment.AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
return cfg, nil
--
GitLab
From 37123cef8fc14ca3d4b336a709167075ff8ea0df Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 18 Apr 2026 14:42:55 +0800
Subject: [PATCH 018/261] docs(payment): add Kyren Topup as international
EasyPay provider option
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Restructure the EasyPay recommendation block to present two options side
by side so users can pick by funding channel and settlement currency:
- Domestic / CNY — ZPay: official Alipay/WeChat API with 1.6% fee and
T+1 automatic settlement (existing recommendation, expanded with fee
and settlement details).
- International / USDT or USD — Kyren Topup (https://kyren.top): global
payment stack supporting WeChat Pay and Alipay with local-currency
checkout, USD settlement, and USDT/USD withdrawal. Fees: WeChat 2%,
Alipay 2.5%, withdrawal 0.1% ($40 min / $150 max). Fills the gap for
users who cannot use domestic Chinese channels or tolerate Stripe's
6%+ fees.
Both recommendations share a single disclaimer at the end. The Chinese
heading uses "易支付" while the English one keeps "EasyPay".
---
docs/PAYMENT.md | 7 ++++++-
docs/PAYMENT_CN.md | 7 ++++++-
2 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index b66a791c..2735aea3 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -28,7 +28,12 @@ Sub2API has a built-in payment system that enables user self-service top-up with
> Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
-> **EasyPay Recommendation**: [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`) is recommended as an EasyPay provider (link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it). ZPay supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
+> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
+>
+> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Link contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code — feel free to remove it.
+>
+> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
---
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 9d96557f..474700e5 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -28,7 +28,12 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。
-> **EasyPay 推荐**:个人推荐 [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`)作为 EasyPay 服务商(链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉)。ZPay 支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
+> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
+>
+> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
+> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。链接含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码,介意可去掉。
+>
+> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
---
--
GitLab
From 6ae1cc8f3f893ea1c6896ee045c8e26ff4b2f16d Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 18 Apr 2026 14:45:25 +0800
Subject: [PATCH 019/261] =?UTF-8?q?docs:=20use=20=E6=98=93=E6=94=AF?=
=?UTF-8?q?=E4=BB=98=20in=20Chinese=20coexistence=20note?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
docs/PAYMENT_CN.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 474700e5..11325a80 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -26,7 +26,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 |
| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。
+> 支付宝官方 / 微信官方与易支付可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
>
--
GitLab
From 0c538a584fbf48414fe3c11b9a3c4ddb2dad96d5 Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 18 Apr 2026 14:48:42 +0800
Subject: [PATCH 020/261] docs: note Kyren Topup $200 account fee waived via
referral link
---
docs/PAYMENT.md | 2 +-
docs/PAYMENT_CN.md | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 2735aea3..755b313a 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -31,7 +31,7 @@ Sub2API has a built-in payment system that enables user self-service top-up with
> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
>
> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
-> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Link contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code — feel free to remove it.
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer.
>
> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 11325a80..aca3c866 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -31,7 +31,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
>
> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
-> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。链接含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码,介意可去掉。
+> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。
>
> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
--
GitLab
From c3cb0280ef8dff949ad7ad8f84793872b9304686 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 01:40:25 +0800
Subject: [PATCH 021/261] fix(payment): alipay redirect-only flow, H5 detection
and popup sizing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The native Alipay provider previously tried to embed the payment page
URL into a QR code on the client — the URL is not a scannable payload
so the QR never worked. Merchants also hit a H5 detection mismatch
whenever the backend UA sniffer missed iPadOS 13+ or embedded browsers,
and the popup window was too small for Alipay's standard checkout
layout (QR + account-login panel on the right), forcing the user to
scroll horizontally and vertically.
Changes:
Backend
- alipay.go: drop QR-on-URL path. Use redirect-only flow —
alipay.trade.page.pay for PC (returns a gateway URL the browser
opens in a new window) and alipay.trade.wap.pay for H5 (returns a
URL the browser jumps to). Both flows produce pages on
openapi.alipaydev.com / excashier.alipay.com; the client never
renders a QR itself.
- payment_handler.go: add optional is_mobile bool to
CreateOrderRequest so the frontend can declare the device
explicitly. Server still falls back to UA sniffing when absent.
Frontend
- types/payment.ts, PaymentView.vue: declare is_mobile in
CreateOrderRequest and pass the computed isMobileDevice() value.
- providerConfig.ts: replace the two fixed POPUP_WINDOW_FEATURES
constants with getPaymentPopupFeatures(), which prefers 1250×900
(Alipay's checkout footprint), clamps to window.screen.avail* and
centers the popup so it never overflows on smaller laptops.
- PaymentQRDialog.vue, PaymentStatusPanel.vue, StripePaymentInline.vue,
PaymentView.vue: use the new helper at all popup call sites.
---
backend/internal/handler/payment_handler.go | 10 +++-
backend/internal/payment/provider/alipay.go | 50 ++++++++++---------
.../components/payment/PaymentQRDialog.vue | 4 +-
.../components/payment/PaymentStatusPanel.vue | 4 +-
.../payment/StripePaymentInline.vue | 4 +-
.../src/components/payment/providerConfig.ts | 23 +++++++--
frontend/src/types/payment.ts | 1 +
frontend/src/views/user/PaymentView.vue | 5 +-
8 files changed, 64 insertions(+), 37 deletions(-)
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 1ddb8ae2..854dca54 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -206,6 +206,10 @@ type CreateOrderRequest struct {
PaymentType string `json:"payment_type" binding:"required"`
OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"`
+ // IsMobile lets the frontend declare its mobile status directly. When
+ // nil we fall back to User-Agent heuristics (which miss iPadOS / some
+ // embedded browsers that strip the "Mobile" keyword).
+ IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
return
}
+ mobile := isMobile(c)
+ if req.IsMobile != nil {
+ mobile = *req.IsMobile
+ }
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
UserID: subject.UserID,
Amount: req.Amount,
PaymentType: req.PaymentType,
ClientIP: c.ClientIP(),
- IsMobile: isMobile(c),
+ IsMobile: mobile,
SrcHost: c.Request.Host,
SrcURL: c.Request.Referer(),
OrderType: req.OrderType,
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index af8a90c6..fe8ea89c 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -15,8 +15,8 @@ import (
// Alipay product codes.
const (
- alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay = "QUICK_WAP_WAY"
+ alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
)
// Alipay response constants.
@@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
-// CreatePayment creates an Alipay payment page URL.
+// CreatePayment creates an Alipay payment using redirect-only flow:
+// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
+// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
+// new window; Alipay's own page then shows login/QR. We intentionally do
+// NOT encode the URL into a QR on the client (it isn't a scannable payload
+// and would produce an invalid scan result).
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
@@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
}
if req.IsMobile {
- return a.createTrade(client, req, notifyURL, returnURL, true)
+ return a.createWapTrade(client, req, notifyURL, returnURL)
}
- return a.createTrade(client, req, notifyURL, returnURL, false)
+ return a.createPagePayTrade(client, req, notifyURL, returnURL)
}
-func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
- if isMobile {
- param := alipay.TradeWapPay{}
- param.OutTradeNo = req.OrderID
- param.TotalAmount = req.Amount
- param.Subject = req.Subject
- param.ProductCode = alipayProductCodeWapPay
- param.NotifyURL = notifyURL
- param.ReturnURL = returnURL
-
- payURL, err := client.TradeWapPay(param)
- if err != nil {
- return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
- }
- return &payment.CreatePaymentResponse{
- TradeNo: req.OrderID,
- PayURL: payURL.String(),
- }, nil
+func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradeWapPay{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodeWapPay
+ param.NotifyURL = notifyURL
+ param.ReturnURL = returnURL
+
+ payURL, err := client.TradeWapPay(param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ PayURL: payURL.String(),
+ }, nil
+}
+func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
param := alipay.TradePagePay{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
@@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
- QRCode: payURL.String(),
}, nil
}
diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue
index b9026e78..db90c3b6 100644
--- a/frontend/src/components/payment/PaymentQRDialog.vue
+++ b/frontend/src/components/payment/PaymentQRDialog.vue
@@ -79,7 +79,7 @@ import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import QRCode from 'qrcode'
import alipayIcon from '@/assets/icons/alipay.svg'
@@ -147,7 +147,7 @@ function getLogoForType(): string | null {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
}
}
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 974dee66..17541e59 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -125,7 +125,7 @@ import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import Icon from '@/components/icons/Icon.vue'
import QRCode from 'qrcode'
@@ -194,7 +194,7 @@ const countdownDisplay = computed(() => {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
}
}
diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue
index b8fd55ef..3ddff8c8 100644
--- a/frontend/src/components/payment/StripePaymentInline.vue
+++ b/frontend/src/components/payment/StripePaymentInline.vue
@@ -70,7 +70,7 @@ import { useRouter } from 'vue-router'
import { extractApiErrorMessage } from '@/utils/apiError'
import { paymentAPI } from '@/api/payment'
import { useAppStore } from '@/stores'
-import { STRIPE_POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { Stripe, StripeElements } from '@stripe/stripe-js'
import Icon from '@/components/icons/Icon.vue'
@@ -151,7 +151,7 @@ async function handlePay() {
amount: String(props.payAmount),
},
}).href
- const popup = window.open(popupUrl, 'paymentPopup', STRIPE_POPUP_WINDOW_FEATURES)
+ const popup = window.open(popupUrl, 'paymentPopup', getPaymentPopupFeatures())
const onReady = (event: MessageEvent) => {
if (event.source !== popup || event.data?.type !== 'STRIPE_POPUP_READY') return
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index a83787fd..bf2d4177 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -43,11 +43,24 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct',
export const PAYMENT_MODE_QRCODE = 'qrcode'
export const PAYMENT_MODE_POPUP = 'popup'
-/** Window features for payment popup windows */
-export const POPUP_WINDOW_FEATURES = 'width=1000,height=750,left=100,top=80,scrollbars=yes,resizable=yes'
-
-/** Wider popup for Stripe redirect methods (Alipay checkout page needs ~1200px) */
-export const STRIPE_POPUP_WINDOW_FEATURES = 'width=1250,height=780,left=80,top=60,scrollbars=yes,resizable=yes'
+/** Preferred popup size for payment gateways. Alipay's standard checkout
+ * (QR + account login panel) needs ~1200×900 to render without any scrolling. */
+const PAYMENT_POPUP_PREFERRED_WIDTH = 1250
+const PAYMENT_POPUP_PREFERRED_HEIGHT = 900
+
+/** Build a window.open features string sized to fit within the current screen
+ * while preferring the above dimensions. Centers the popup on the available
+ * work area so nothing is clipped on smaller laptop displays. */
+export function getPaymentPopupFeatures(): string {
+ const screen = typeof window !== 'undefined' ? window.screen : null
+ const availW = screen?.availWidth ?? PAYMENT_POPUP_PREFERRED_WIDTH
+ const availH = screen?.availHeight ?? PAYMENT_POPUP_PREFERRED_HEIGHT
+ const width = Math.min(PAYMENT_POPUP_PREFERRED_WIDTH, availW - 40)
+ const height = Math.min(PAYMENT_POPUP_PREFERRED_HEIGHT, availH - 40)
+ const left = Math.max(0, Math.floor((availW - width) / 2))
+ const top = Math.max(0, Math.floor((availH - height) / 2))
+ return `width=${width},height=${height},left=${left},top=${top},scrollbars=yes,resizable=yes`
+}
/** Webhook paths for each provider (relative to origin). */
export const WEBHOOK_PATHS: Record<string, string> = {
diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts
index 7ecbb9a9..6f2eec51 100644
--- a/frontend/src/types/payment.ts
+++ b/frontend/src/types/payment.ts
@@ -154,6 +154,7 @@ export interface CreateOrderRequest {
payment_type: string
order_type: string
plan_id?: number
+ is_mobile?: boolean
}
export interface CreateOrderResult {
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index e91df5da..3f1401b3 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -277,7 +277,7 @@ import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/
import AppLayout from '@/components/layout/AppLayout.vue'
import AmountInput from '@/components/payment/AmountInput.vue'
import PaymentMethodSelector from '@/components/payment/PaymentMethodSelector.vue'
-import { METHOD_ORDER, POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { METHOD_ORDER, getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import { platformAccentBarClass, platformBadgeLightClass, platformBadgeClass, platformTextClass, platformLabel } from '@/utils/platformColors'
import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue'
import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue'
@@ -551,9 +551,10 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
payment_type: selectedMethod.value,
order_type: orderType,
plan_id: planId,
+ is_mobile: isMobileDevice(),
})
const openWindow = (url: string) => {
- const win = window.open(url, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ const win = window.open(url, 'paymentPopup', getPaymentPopupFeatures())
if (!win || win.closed) {
window.location.href = url
}
--
GitLab
From 235f710853b3e44b5b2fa236cfefd0b9d0c2a044 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 01:46:50 +0800
Subject: [PATCH 022/261] feat(payment): redact provider secrets in admin
config API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Admin GET /api/v1/admin/payment/providers previously returned every
config value — including privateKey / apiV3Key / secretKey etc. —
verbatim. Any future XSS on the admin UI would hand attackers the
full set of production payment credentials, and the plaintext values
sat unnecessarily in browser memory for every operator.
Treat those fields as write-only from the admin surface:
- decryptAndMaskConfig() strips sensitive keys from the GET response.
The authoritative list is an explicit per-provider registry that
mirrors the frontend's PROVIDER_CONFIG_FIELDS sensitive flag:
alipay → privateKey, publicKey, alipayPublicKey
wxpay → privateKey, apiV3Key, publicKey
stripe → secretKey, webhookSecret (publishableKey stays plain)
easypay → pkey
Payment runtime still reads the full config via decryptConfig, so
nothing at the gateway changes.
- mergeConfig() treats an empty value for a sensitive key as "leave
unchanged" — the admin UI omits unchanged secrets so operators can
tweak non-sensitive settings without re-entering credentials.
- Admin dialog (PaymentProviderDialog.vue):
* secret inputs get autocomplete="new-password", data-1p-ignore,
data-lpignore and data-bwignore so password managers do not
offer to save provider credentials
* in edit mode the required-field check skips sensitive fields
(empty is the "keep existing" signal) and the placeholder shows
"leave empty to keep" instead of the default example value
* create mode still requires every non-optional field, including
secrets, since there is nothing to preserve
- Unit test renamed to TestIsSensitiveProviderConfigField, covers
the per-provider registry and specifically asserts that Stripe's
publishableKey is NOT treated as a secret.
---
.../service/payment_config_providers.go | 78 +++++++++++++++----
.../service/payment_config_providers_test.go | 63 ++++++++-------
.../payment/PaymentProviderDialog.vue | 23 ++++--
3 files changed, 119 insertions(+), 45 deletions(-)
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 8e470525..3d1e4dc4 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -52,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
- resp.Config, err = s.decryptAndMaskConfig(inst.Config)
+ resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
}
@@ -61,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
return result, nil
}
-func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
- return s.decryptConfig(encrypted)
+// decryptAndMaskConfig returns the stored config with sensitive fields omitted.
+// Admin UIs display masked placeholders for these; the raw values never leave
+// the server. Callers that need the full config (e.g. payment runtime) must
+// use decryptConfig directly.
+func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
+ cfg, err := s.decryptConfig(encrypted)
+ if err != nil {
+ return nil, err
+ }
+ if cfg == nil {
+ return nil, nil
+ }
+ masked := make(map[string]string, len(cfg))
+ for k, v := range cfg {
+ if isSensitiveProviderConfigField(providerKey, k) {
+ continue
+ }
+ masked[k] = v
+ }
+ return masked, nil
}
// pendingOrderStatuses are order statuses considered "in progress".
@@ -72,16 +90,27 @@ var pendingOrderStatuses = []string{
payment.OrderStatusRecharging,
}
-var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
+// providerSensitiveConfigFields is the authoritative list of config keys that
+// are treated as secrets per provider. Must stay in sync with the frontend
+// definition at frontend/src/components/payment/providerConfig.ts
+// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
+//
+// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
+// stripe publishableKey) are returned in plaintext by the admin GET API.
+var providerSensitiveConfigFields = map[string]map[string]struct{}{
+ payment.TypeEasyPay: {"pkey": {}},
+ payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
+ payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
+ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
+}
-func isSensitiveConfigField(fieldName string) bool {
- lower := strings.ToLower(fieldName)
- for _, p := range sensitiveConfigPatterns {
- if strings.Contains(lower, p) {
- return true
- }
+func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
+ fields, ok := providerSensitiveConfigFields[providerKey]
+ if !ok {
+ return false
}
- return false
+ _, found := fields[strings.ToLower(fieldName)]
+ return found
}
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
@@ -137,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
+ var cachedInst *dbent.PaymentProviderInstance
+ loadInst := func() (*dbent.PaymentProviderInstance, error) {
+ if cachedInst != nil {
+ return cachedInst, nil
+ }
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance: %w", err)
+ }
+ cachedInst = inst
+ return inst, nil
+ }
if req.Config != nil {
+ inst, err := loadInst()
+ if err != nil {
+ return nil, err
+ }
hasSensitive := false
- for k := range req.Config {
- if isSensitiveConfigField(k) && req.Config[k] != "" {
+ for k, v := range req.Config {
+ if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
hasSensitive = true
break
}
@@ -283,9 +328,14 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
}
if existing == nil {
- return newConfig, nil
+ existing = map[string]string{}
}
for k, v := range newConfig {
+ // Preserve existing secrets when the client submits an empty value
+ // (admin UI omits the value to indicate "leave unchanged").
+ if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
+ continue
+ }
existing[k] = v
}
return existing, nil
diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go
index 2aaa874f..bc2a9b18 100644
--- a/backend/internal/service/payment_config_providers_test.go
+++ b/backend/internal/service/payment_config_providers_test.go
@@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) {
}
}
-func TestIsSensitiveConfigField(t *testing.T) {
+func TestIsSensitiveProviderConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
- field string
- wantSen bool
+ providerKey string
+ field string
+ wantSen bool
}{
- // Sensitive fields (contain key/secret/private/password/pkey patterns)
- {"secretKey", true},
- {"apiSecret", true},
- {"pkey", true},
- {"privateKey", true},
- {"apiPassword", true},
- {"appKey", true},
- {"SECRET_TOKEN", true},
- {"PrivateData", true},
- {"PASSWORD", true},
- {"mySecretValue", true},
-
- // Non-sensitive fields
- {"appId", false},
- {"mchId", false},
- {"apiBase", false},
- {"endpoint", false},
- {"merchantNo", false},
- {"paymentMode", false},
- {"notifyUrl", false},
+ // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
+ {"stripe", "secretKey", true},
+ {"stripe", "webhookSecret", true},
+ {"stripe", "SecretKey", true}, // case-insensitive
+ {"stripe", "publishableKey", false},
+ {"stripe", "appId", false},
+
+ // Alipay
+ {"alipay", "privateKey", true},
+ {"alipay", "publicKey", true},
+ {"alipay", "alipayPublicKey", true},
+ {"alipay", "appId", false},
+ {"alipay", "notifyUrl", false},
+
+ // Wxpay
+ {"wxpay", "privateKey", true},
+ {"wxpay", "apiV3Key", true},
+ {"wxpay", "publicKey", true},
+ {"wxpay", "publicKeyId", false},
+ {"wxpay", "certSerial", false},
+ {"wxpay", "mchId", false},
+
+ // EasyPay
+ {"easypay", "pkey", true},
+ {"easypay", "pid", false},
+ {"easypay", "apiBase", false},
+
+ // Unknown provider: never sensitive
+ {"unknown", "secretKey", false},
}
for _, tc := range tests {
- t.Run(tc.field, func(t *testing.T) {
+ tc := tc
+ t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
t.Parallel()
- got := isSensitiveConfigField(tc.field)
- assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
+ got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
+ assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
})
}
}
diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue
index 10c1bfea..624ddcdd 100644
--- a/frontend/src/components/payment/PaymentProviderDialog.vue
+++ b/frontend/src/components/payment/PaymentProviderDialog.vue
@@ -88,13 +88,24 @@
v-model="config[field.key]"
rows="3"
class="input font-mono text-xs"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
+ spellcheck="false"
+ :placeholder="editing ? t('admin.accounts.leaveEmptyToKeep') : ''"
/>
  const filteredConfig: Record<string, string> = {}
for (const [k, v] of Object.entries(config)) {
if (!v || !v.trim()) continue
- // Skip masked values — backend keeps existing credentials
- if (v === '••••••••') continue
filteredConfig[k] = v
}
@@ -470,7 +482,8 @@ function loadProvider(provider: ProviderInstance) {
form.refund_enabled = provider.refund_enabled
form.allow_user_refund = provider.allow_user_refund
clearConfig()
- // Pre-fill config from API response (non-sensitive in cleartext, sensitive masked as ••••••••)
+ // Pre-fill config from API response. Backend omits sensitive fields entirely,
+ // so those inputs stay blank — submitting blank preserves the stored secret.
if (provider.config) {
for (const [k, v] of Object.entries(provider.config)) {
// Skip notifyUrl/returnUrl — they are derived from callbackBaseUrl
--
GitLab
From 6530776a62dd355ccd3f9bca1ca3ed41f715bdd7 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 18:05:25 +0800
Subject: [PATCH 023/261] fix: support xhigh reasoning effort in usage records
for Claude Messages API
Closes #1732
---
backend/internal/service/gateway_request.go | 2 +-
backend/internal/service/gateway_request_test.go | 8 +++++++-
frontend/src/utils/format.ts | 4 +++-
3 files changed, 11 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go
index 55cb2c84..498336a4 100644
--- a/backend/internal/service/gateway_request.go
+++ b/backend/internal/service/gateway_request.go
@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string {
return nil
}
switch value {
- case "low", "medium", "high", "max":
+ case "low", "medium", "high", "xhigh", "max":
return &value
default:
return nil
diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go
index d262456d..40bd1186 100644
--- a/backend/internal/service/gateway_request_test.go
+++ b/backend/internal/service/gateway_request_test.go
@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) {
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
wantEffort: "max",
},
+ {
+ name: "output_config.effort xhigh",
+ body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`,
+ wantEffort: "xhigh",
+ },
{
name: "output_config without effort",
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) {
{"LOW", strPtr("low")},
{"Max", strPtr("max")},
{" medium ", strPtr("medium")},
+ {"xhigh", strPtr("xhigh")},
+ {"XHIGH", strPtr("xhigh")},
{"", nil},
{"unknown", nil},
- {"xhigh", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
diff --git a/frontend/src/utils/format.ts b/frontend/src/utils/format.ts
index ea5d597b..6f13a065 100644
--- a/frontend/src/utils/format.ts
+++ b/frontend/src/utils/format.ts
@@ -193,7 +193,9 @@ export function formatReasoningEffort(effort: string | null | undefined): string
return 'High'
case 'xhigh':
case 'extrahigh':
- return 'Xhigh'
+ return 'XHigh'
+ case 'max':
+ return 'Max'
case 'none':
case 'minimal':
return '-'
--
GitLab
From 258fd145ff8a2355befc03166f34c3d498da7aa8 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 18:45:04 +0800
Subject: [PATCH 024/261] fix(account): prevent quota-exceeded API key/Bedrock
accounts from being scheduled
Add quota exceeded check to IsSchedulable() and refactor
shouldClearStickySession to delegate to IsSchedulable(), eliminating
duplicated logic and fixing missed overload/rate-limit/expired checks.
Frontend displays quota exceeded status independently via quota fields.
---
backend/internal/service/account.go | 3 +
.../service/account_quota_schedulable_test.go | 123 ++++++++++++++++++
backend/internal/service/gateway_service.go | 15 +--
.../internal/service/sticky_session_test.go | 66 ++++++++--
.../account/AccountStatusIndicator.vue | 33 +++--
frontend/src/i18n/locales/en.ts | 1 +
frontend/src/i18n/locales/zh.ts | 1 +
7 files changed, 207 insertions(+), 35 deletions(-)
create mode 100644 backend/internal/service/account_quota_schedulable_test.go
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 52db3073..af686ae7 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -121,6 +121,9 @@ func (a *Account) IsSchedulable() bool {
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
+ if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() {
+ return false
+ }
return true
}
diff --git a/backend/internal/service/account_quota_schedulable_test.go b/backend/internal/service/account_quota_schedulable_test.go
new file mode 100644
index 00000000..2895b34c
--- /dev/null
+++ b/backend/internal/service/account_quota_schedulable_test.go
@@ -0,0 +1,123 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) {
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "apikey daily quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey weekly quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_weekly_limit": 50.0,
+ "quota_weekly_used": 50.0,
+ "quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey total quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_limit": 100.0,
+ "quota_used": 100.0,
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey quota not exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 5.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "apikey expired daily period restores schedulable",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-25 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "oauth ignores quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "bedrock quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeBedrock,
+ Extra: map[string]any{
+ "quota_limit": 200.0,
+ "quota_used": 200.0,
+ },
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, tt.account.IsSchedulable())
+ })
+ }
+}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 07a9e41c..5a91d0de 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -435,26 +435,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
-// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
-// 或请求的模型处于限流状态时,返回 true。
-// 这确保后续请求不会继续使用不可用的账号。
+// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等),
+// 额外检查模型级限流。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
-// Returns true when account status is error/disabled, schedulable is false,
-// within temporary unschedulable period, or the requested model is rate-limited.
-// This ensures subsequent requests won't continue using unavailable accounts.
+// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
- if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
+ if !account.IsSchedulable() {
return true
}
- if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
- return true
- }
- // 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go
index e7ef8982..11ace7bd 100644
--- a/backend/internal/service/sticky_session_test.go
+++ b/backend/internal/service/sticky_session_test.go
@@ -15,20 +15,8 @@ import (
"github.com/stretchr/testify/require"
)
-// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
-// 验证在以下情况下是否正确判断需要清理粘性会话:
-// - nil 账号:不清理(返回 false)
-// - 状态为错误或禁用:清理
-// - 不可调度:清理
-// - 临时不可调度且未过期:清理
-// - 临时不可调度已过期:不清理
-// - 正常可调度状态:不清理
-// - 模型限流(任意时长):清理
-//
-// TestShouldClearStickySession tests the sticky session clearing logic.
-// Verifies correct behavior for various account states including:
-// nil account, error/disabled status, unschedulable, temporary unschedulable,
-// and model rate limiting scenarios.
+// TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation
+// plus model-level rate limiting.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) {
requestedModel: "claude-opus-4", // 请求不同模型
want: false, // 不同模型不受影响
},
+ {
+ name: "apikey quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "oauth quota exceeded not cleared",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: false,
+ },
+ {
+ name: "overloaded account",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ OverloadUntil: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "account-level rate limited",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
}
for _, tt := range tests {
diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue
index fc2f7d0c..dd38a49f 100644
--- a/frontend/src/components/account/AccountStatusIndicator.vue
+++ b/frontend/src/components/account/AccountStatusIndicator.vue
@@ -284,6 +284,16 @@ const hasError = computed(() => {
return props.account.status === 'error'
})
+const isQuotaExceeded = computed(() => {
+ const exceeded = (used?: number | null, limit?: number | null) =>
+ typeof limit === 'number' && limit > 0 && typeof used === 'number' && used >= limit
+ return (
+ exceeded(props.account.quota_used, props.account.quota_limit) ||
+ exceeded(props.account.quota_daily_used, props.account.quota_daily_limit) ||
+ exceeded(props.account.quota_weekly_used, props.account.quota_weekly_limit)
+ )
+})
+
// Computed: countdown text for rate limit (429)
const rateLimitCountdown = computed(() => {
return formatCountdown(props.account.rate_limit_reset_at)
@@ -307,19 +317,16 @@ const statusClass = computed(() => {
if (isTempUnschedulable.value) {
return 'badge-warning'
}
+ if (props.account.status !== 'active') {
+ return props.account.status === 'error' ? 'badge-danger' : 'badge-gray'
+ }
+ if (isQuotaExceeded.value) {
+ return 'badge-warning'
+ }
if (!props.account.schedulable) {
return 'badge-gray'
}
- switch (props.account.status) {
- case 'active':
- return 'badge-success'
- case 'inactive':
- return 'badge-gray'
- case 'error':
- return 'badge-danger'
- default:
- return 'badge-gray'
- }
+ return 'badge-success'
})
// Computed: status text
@@ -330,6 +337,12 @@ const statusText = computed(() => {
if (isTempUnschedulable.value) {
return t('admin.accounts.status.tempUnschedulable')
}
+ if (props.account.status !== 'active') {
+ return t(`admin.accounts.status.${props.account.status}`)
+ }
+ if (isQuotaExceeded.value) {
+ return t('admin.accounts.status.quotaExceeded')
+ }
if (!props.account.schedulable) {
return t('admin.accounts.status.paused')
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index d1def45c..c0a17d96 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2126,6 +2126,7 @@ export default {
rateLimited: 'Rate Limited',
overloaded: 'Overloaded',
tempUnschedulable: 'Temp Unschedulable',
+ quotaExceeded: 'Quota Exceeded',
unschedulable: 'Unschedulable',
rateLimitedUntil: 'Rate limited and removed from scheduling. Auto resumes at {time}',
rateLimitedAutoResume: 'Auto resumes in {time}',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 6f57ab3e..ba9edd7f 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2315,6 +2315,7 @@ export default {
rateLimited: '限流中',
overloaded: '过载中',
tempUnschedulable: '临时不可调度',
+ quotaExceeded: '配额超限',
unschedulable: '不可调度',
rateLimitedUntil: '限流中,当前不参与调度,预计 {time} 自动恢复',
rateLimitedAutoResume: '{time} 自动恢复',
--
GitLab
From 6579f28b649486cd3bfc9de3f2ed442f87707daf Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 19 Apr 2026 20:38:57 +0800
Subject: [PATCH 025/261] fix: delete scheduled test plans when account is
deleted
Accounts use soft-delete (setting deleted_at), so PostgreSQL's
ON DELETE CASCADE on scheduled_test_plans.account_id never fires.
Add plan deletion to the existing account deletion transaction
to ensure atomicity.
Closes Wei-Shaw/sub2api#1728
---
backend/internal/repository/account_repo.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 24115c33..78f739ac 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
+ if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE account_id = $1", id); err != nil {
+ return err
+ }
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
--
GitLab
From 23def40bc5415c04ca3a05bb6d67a6ff1e4a3566 Mon Sep 17 00:00:00 2001
From: shaw
Date: Sun, 19 Apr 2026 22:06:04 +0800
Subject: [PATCH 026/261] chore: change license from MIT to LGPL v3.0
---
LICENSE | 186 +++++++++++++++++++++++++++++++++++++++++++++------
README.md | 4 +-
README_CN.md | 4 +-
README_JA.md | 4 +-
4 files changed, 174 insertions(+), 24 deletions(-)
diff --git a/LICENSE b/LICENSE
index 7a94ca9d..153d416d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,165 @@
-MIT License
-
-Copyright (c) 2025 Wesley Liddick
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
\ No newline at end of file
diff --git a/README.md b/README.md
index 74ab9af2..bee2e8c3 100644
--- a/README.md
+++ b/README.md
@@ -618,7 +618,9 @@ sub2api/
## License
-MIT License
+This project is licensed under the [GNU Lesser General Public License v3.0](LICENSE) (or later).
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_CN.md b/README_CN.md
index c701372c..892eee61 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -679,7 +679,9 @@ sub2api/
## 许可证
-MIT License
+本项目基于 [GNU 宽通用公共许可证 v3.0](LICENSE)(或更高版本)授权。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_JA.md b/README_JA.md
index 0d4db616..6f0fc900 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -617,7 +617,9 @@ sub2api/
## ライセンス
-MIT License
+本プロジェクトは [GNU Lesser General Public License v3.0](LICENSE)(またはそれ以降のバージョン)の下でライセンスされています。
+
+Copyright (c) 2026 Wesley Liddick
---
--
GitLab
From e01c1eacebc4b4d7cce25de83fb2cb94a76c10fb Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 13:18:30 +0800
Subject: [PATCH 027/261] docs: add auth identity payment foundation design
spec
---
...auth-identity-payment-foundation-design.md | 540 ++++++++++++++++++
1 file changed, 540 insertions(+)
create mode 100644 docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
new file mode 100644
index 00000000..a60cfdca
--- /dev/null
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -0,0 +1,540 @@
+# Auth Identity And Payment Foundation Design
+
+**Date:** 2026-04-20
+
+**Status:** Draft approved in conversation, written for implementation planning
+
+**Goal**
+
+Rebuild the `feat/auth-identity-foundation` intent on a clean branch from `main`, covering unified user identity, third-party login and binding, profile adoption, source-based signup defaults, unified payment routing and UX, admin configuration, compatibility with existing `main` data, and an opt-in OpenAI advanced scheduling switch.
+
+## Scope
+
+This design includes:
+
+- Email login and registration
+- Third-party login and binding for `LinuxDo`, `OIDC`, and `WeChat`
+- Unified identity storage for email and third-party identities
+- Pending auth sessions for callback-to-login/register/bind continuation
+- User-controlled nickname/avatar adoption during first relevant third-party flow
+- Profile binding management and avatar upload/delete
+- Source-based initial grants for balance, concurrency, and subscriptions
+- User management support for `last_login_at` and `last_active_at` sorting
+- Unified payment display methods (`alipay`, `wechat`) mapped to a single active backend source each
+- Alipay and WeChat UX routing rules across PC, mobile, H5, and WeChat environments
+- Admin settings for auth providers, source defaults, payment sources, and OpenAI advanced scheduling
+- Incremental migration and compatibility for existing email users and historical LinuxDo synthetic-email users
+
+This design does not treat unrelated upstream merges, docs churn, or license changes from the old branch as required scope.
+
+## Product Rules
+
+### Auth and identity
+
+- Existing email users remain valid and continue to log in with no manual action.
+- Third-party first login behavior:
+ - Existing bound identity: direct login
+ - Missing identity: start first-login flow
+- If `force_email_on_third_party_signup` is disabled, a first-login user may create an account without binding an email.
+- If `force_email_on_third_party_signup` is enabled, the user must provide an email.
+- If the provided and verified email already exists:
+ - show that the email already exists
+ - allow "verify and bind existing account"
+ - allow "change email and continue registration"
+ - do not allow bypassing the email requirement
+- Email verification performed by an upstream provider is not trusted as verification of a locally bound email; local binding requires local verification.
+- WeChat login chooses channel by environment:
+ - in WeChat environment: `mp`
+ - outside WeChat: `open`
+- WeChat primary identity key is `unionid`.
+- If a WeChat login/bind flow cannot produce `unionid`, the flow fails and no fallback `openid` identity is created.
+
+### Profile adoption
+
+- During the first relevant third-party flow, the user can independently decide:
+ - replace current nickname or not
+ - replace current avatar or not
+- This applies to first third-party registration and first third-party binding.
+- The decision is explicit user choice, not automatic replacement.
+
+### Source-based initial grants
+
+- Source-specific defaults exist for `email`, `linuxdo`, `oidc`, and `wechat`.
+- Each source defines:
+ - default balance
+ - default concurrency
+ - default subscriptions
+ - grant on signup
+ - grant on first bind
+- Default behavior:
+ - grant on signup: enabled
+ - grant on first bind: disabled
+- First-bind grants are optional and controlled per source.
+- Grants must be idempotent.
+
+### Avatar management
+
+- Avatar supports:
+ - external URL
+ - image `data:` URL
+- `data:` URL images are compressed to at most `100KB` before persistence.
+- Avatar storage is database-backed.
+- Avatar delete is supported.
+
+### Payment UX and routing
+
+- Frontend shows only two display methods:
+ - `alipay`
+ - `wechat`
+- Users never choose between official providers and EasyPay explicitly.
+- Backend allows only one active source per display method at a time.
+- Alipay UX:
+ - PC: show QR code in page
+ - mobile: jump to Alipay app/payment flow
+- WeChat UX:
+ - PC: show QR code in page
+ - non-WeChat H5: prefer H5 pay; if unavailable, tell the user to open in WeChat
+ - WeChat environment: prefer MP/JSAPI pay; if unavailable, fall back to H5 pay
+- Payment success is confirmed by backend order state, webhook, and/or query, not only frontend return.
+
+### OpenAI advanced scheduling
+
+- OpenAI advanced scheduling is supported.
+- It is disabled by default.
+- Admin can enable it explicitly.
+
+## Architecture
+
+Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
+
+Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix.
+
+Compatibility is a first-class concern: migrations are additive, reads are compatibility-aware, and rollout must tolerate existing `main` data and short-lived frontend/backend version skew.
+
+## Data Model
+
+### `users`
+
+Preserve existing account ownership and local-login fields. Extend or use:
+
+- `email`
+- `password_hash`
+- `totp_enabled`
+- `signup_source`
+- `last_login_at`
+- `last_active_at`
+
+The `users` table remains the primary business subject for balance, concurrency, subscriptions, permissions, and profile.
+
+### `auth_identities`
+
+Represents all canonical login or bindable identities.
+
+Fields:
+
+- `user_id`
+- `provider_type`: `email`, `linuxdo`, `oidc`, `wechat`
+- `provider_key`
+- `provider_subject`
+- `verified_at`
+- `issuer`
+- `metadata`
+- timestamps
+
+Uniqueness:
+
+- `provider_type + provider_key + provider_subject` must be unique
+
+Rules:
+
+- email identity uses canonicalized local email
+- LinuxDo uses stable provider subject
+- OIDC uses stable issuer + subject
+- WeChat uses `unionid` as canonical subject
+
+### `auth_identity_channels`
+
+Stores channel-specific subject mappings for an identity.
+
+Primary use:
+
+- WeChat `open` / `mp` / payment channel mapping
+
+Fields:
+
+- `identity_id`
+- `provider_type`
+- `provider_key`
+- `channel`
+- `channel_app_id`
+- `channel_subject`
+- `metadata`
+- timestamps
+
+Rules:
+
+- canonical WeChat identity still keys on `unionid`
+- `openid` values live here as channel mappings
+
+### `pending_auth_sessions`
+
+Stores callback state between third-party callback and final account action.
+
+Fields:
+
+- `intent`
+- `provider_type`
+- `provider_key`
+- `provider_subject`
+- `target_user_id`
+- `redirect_to`
+- `resolved_email`
+- `pending_password_hash`
+- `upstream_identity_payload`
+- `metadata`
+- `email_verified_at`
+- `password_verified_at`
+- `totp_verified_at`
+- `expires_at`
+- `consumed_at`
+- timestamps
+
+Responsibilities:
+
+- continue provider callback into register/login/bind flows
+- persist nickname/avatar suggestions
+- persist explicit adoption decisions
+- survive navigation between auth pages
+
+### `identity_adoption_decisions`
+
+Persists user adoption preference for a specific identity.
+
+Fields:
+
+- `identity_id`
+- `adopt_display_name`
+- `adopt_avatar`
+- `decided_at`
+- timestamps
+
+### `user_avatars`
+
+Stores the currently effective custom avatar.
+
+Fields:
+
+- `user_id`
+- `storage_provider`
+- `storage_key`
+- `url`
+- `content_type`
+- `byte_size`
+- `sha256`
+- timestamps
+
+Rules:
+
+- supports URL-backed and inline data-backed representations
+- hard maximum payload size is `100KB`
+
+### `user_provider_default_grants`
+
+Stores idempotency state for source grants.
+
+Fields:
+
+- `user_id`
+- `provider_type`
+- `granted_at`
+- timestamps
+
+Responsibilities:
+
+- prevent duplicate first-bind grants
+- allow signup grants and first-bind grants to be reasoned about independently
+
+## Identity Keys And Canonicalization
+
+- Email canonical key: `lower(trim(email))`
+- LinuxDo canonical key: provider subject from LinuxDo
+- OIDC canonical key: `issuer + sub`
+- WeChat canonical key: `unionid`
+
+WeChat-specific rule:
+
+- `openid` never becomes the primary stored identity key
+- if only `openid` is available, login/bind fails with a configuration/identity error
+
+## Core Flows
+
+### Email register/login
+
+- Existing email auth flow remains
+- On email registration, create canonical `email` identity
+- Apply `email` source signup defaults
+
+### Third-party login with existing identity
+
+- Resolve canonical identity
+- Login mapped `user`
+- Update `last_login_at`
+- Do not issue signup or first-bind grants again
+
+### Third-party first login with no identity
+
+- Create `pending_auth_session`
+- Frontend callback flow decides next action
+
+Branches:
+
+- no forced email binding:
+ - user can create account directly
+- forced email binding:
+ - user must supply local email
+
+If supplied local email already exists:
+
+- tell the user the email already exists
+- allow verify-and-bind-existing-account
+- allow changing email to continue registration
+
+On new account creation:
+
+- create `users` row
+- create canonical third-party identity
+- apply source signup grants
+- apply adoption choices if selected
+
+### Bind third-party identity to current logged-in user
+
+- current user starts bind flow
+- callback resolves to `bind_current_user`
+- bind canonical identity to current user
+- if configured and first bind for that provider, apply first-bind grants
+- present nickname/avatar replacement choice
+
+### Bind existing account during first-login flow
+
+- verify password for existing account
+- if account requires TOTP, verify TOTP
+- bind canonical identity to target account
+- optionally apply first-bind grants
+- present nickname/avatar replacement choice
+
+### WeChat login and channel mapping
+
+- environment chooses `mp` or `open`
+- callback must resolve to `unionid`
+- channel `openid` is optionally recorded in `auth_identity_channels`
+- failure to obtain `unionid` aborts flow
+
+### Avatar upload and delete
+
+- URL avatar: validate and persist reference
+- data URL avatar:
+ - decode
+ - validate image type
+ - compress to `<=100KB`
+ - persist database-backed inline representation
+- delete removes current custom avatar entry
+
+## Payment Routing Model
+
+### User-visible methods
+
+- `alipay`
+- `wechat`
+
+### Backend source abstraction
+
+Each display method maps to exactly one active configured backend source:
+
+- `official_alipay`
+- `easypay_alipay`
+- `official_wechat`
+- `easypay_wechat`
+
+Frontend submits display method only. Backend resolves display method to active source and capability set.
+
+### Alipay routing
+
+- PC: create QR-oriented result and show QR in page
+- mobile: create jump/redirect-oriented result
+
+### WeChat routing
+
+- PC: QR result
+- non-WeChat H5:
+ - prefer H5 pay
+ - if unavailable, show "open in WeChat" requirement
+- WeChat environment:
+ - prefer MP/JSAPI
+ - if unavailable, fall back to H5 pay
+
+### Payment completion
+
+- frontend return restores context and UI state
+- backend order state remains source of truth
+- webhook and/or order query remain authoritative for fulfillment
+
+## Admin Configuration Model
+
+### Auth provider settings
+
+- email registration and verification settings
+- force email on third-party signup
+- LinuxDo client settings
+- OIDC issuer/client settings and provider display name
+- WeChat `open` and `mp` settings with config-valid and health indicators
+
+### Source default settings
+
+Per source (`email`, `linuxdo`, `oidc`, `wechat`):
+
+- default balance
+- default concurrency
+- default subscriptions
+- grant on signup
+- grant on first bind
+
+### Payment settings
+
+- active source for `alipay`
+- active source for `wechat`
+- source-specific credentials and enablement
+- WeChat capability matrix:
+ - QR available
+ - H5 available
+ - MP/JSAPI available
+
+### Scheduling settings
+
+- OpenAI advanced scheduling enabled/disabled
+- default disabled
+
+## Compatibility And Rollout
+
+Compatibility is mandatory, especially for:
+
+- existing email users
+- existing LinuxDo users
+- historical LinuxDo synthetic-email accounts
+
+### Additive migrations
+
+- preserve existing `users` data and behavior
+- add identity and pending-session tables
+- avoid destructive schema swaps
+
+### Migration backfill
+
+- backfill canonical `email` identities for valid existing email users
+- backfill canonical `linuxdo` identities during migration for historical synthetic-email LinuxDo users
+- backfill must be idempotent and repeatable
+
+### Compatibility reads
+
+During rollout:
+
+- read new identity model first
+- where necessary, retain compatibility logic for existing email and historical LinuxDo synthetic-email recognition
+
+### Grant idempotency
+
+- migration backfill must not trigger signup or first-bind grants
+- first-bind grants must use explicit idempotency tracking
+
+### API compatibility
+
+Retain transitional support for legacy/new request and response shapes where needed, including:
+
+- `pending_auth_token`
+- `pending_oauth_token`
+- old callback parsing expectations
+- historical profile field mappings
+
+### Settings and payment compatibility
+
+- preserve existing payment configs and order semantics from `main`
+- add new settings incrementally
+- avoid rewriting the entire settings schema in one cutover
+
+### Rolling upgrade tolerance
+
+- do not assume simultaneous frontend/backend deployment
+- new backend must tolerate short-lived older frontend request shapes
+
+## Testing Strategy
+
+### Repository tests
+
+- identity upsert and lookup
+- WeChat channel mapping
+- pending auth session persistence
+- source grant idempotency
+- avatar persistence and delete
+- migration backfill behavior
+
+### Service tests
+
+- direct login by existing identity
+- first third-party signup
+- forced email flow
+- existing-email bind-existing-account flow
+- first-bind grant on/off
+- nickname/avatar adoption choices
+- WeChat `unionid` required behavior
+- payment routing resolution
+
+### Handler and route tests
+
+- LinuxDo/OIDC/WeChat callback handling
+- bind-existing
+- bind-current-user
+- create-account
+- TOTP continuation
+- payment create and recovery
+
+### Frontend tests
+
+- third-party callback flow state machine
+- register/login continuation
+- profile bindings card
+- avatar interactions
+- payment page routing behavior
+- admin settings UI
+
+### Compatibility tests
+
+- existing email users
+- historical LinuxDo synthetic-email users
+- historical payment config
+- legacy auth payload field names
+- historical payment result handling
+
+## Implementation Phases
+
+1. Add schema, migrations, compatibility backfill, and repository support
+2. Implement unified identity services and pending auth session flows
+3. Integrate profile binding, avatar, and adoption decision flows
+4. Add per-source default grants and admin config surfaces
+5. Rebuild payment routing abstraction and frontend payment UX
+6. Add user-management sorting and OpenAI advanced scheduling switch
+7. Run compatibility, rollout, and regression hardening
+
+## External Constraints And Best Practices
+
+Implementation must follow current primary-source guidance:
+
+- OAuth 2.0 Security BCP (RFC 9700): strict redirect handling, state protection, mix-up resistant design
+- PKCE (RFC 7636): use on authorization code flows where applicable
+- OpenID Connect Core: stable issuer/subject handling for OIDC identities
+- Account linking best practice: require explicit user confirmation or re-authentication before linking to existing accounts
+
+References:
+
+- RFC 9700 (OAuth 2.0 Security Best Current Practice): https://www.rfc-editor.org/rfc/rfc9700
+- RFC 7636 (Proof Key for Code Exchange): https://www.rfc-editor.org/rfc/rfc7636
+- OpenID Connect Core 1.0: https://openid.net/specs/openid-connect-core-1_0.html
+- Auth0 account linking guidance: https://auth0.com/docs/manage-users/user-accounts/user-account-linking
--
GitLab
From 721d7ab3ab7b9c0ad6c96ec3e0ed23d9c7062951 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 13:40:31 +0800
Subject: [PATCH 028/261] docs: add audit synthesis to auth identity spec
---
...auth-identity-payment-foundation-design.md | 124 ++++++++++++++++++
1 file changed, 124 insertions(+)
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
index a60cfdca..bfa250d9 100644
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -538,3 +538,127 @@ References:
 - RFC 7636 (Proof Key for Code Exchange): https://www.rfc-editor.org/rfc/rfc7636
 - OpenID Connect Core 1.0: https://openid.net/specs/openid-connect-core-1_0.html
 - Auth0 account linking guidance: https://auth0.com/docs/manage-users/user-accounts/user-account-linking
+
+## Audit Synthesis
+
+The clean rebuild direction is not to copy either existing branch directly.
+
+- `feat/auth-identity-foundation` has the better long-term model:
+ - unified auth identities
+ - pending auth sessions
+ - identity adoption decisions
+ - provider-scoped default grants
+ - payment display-method abstraction
+ - OpenAI advanced scheduler layering
+- `personal-dev-branch` has the better real-world closure:
+ - LinuxDo and WeChat callback flows are more operationally complete
+ - profile binding and avatar UX is more complete
+ - historical synthetic-email users are recognized and recovered in live flows
+ - WeChat payment OAuth and recovery behavior is more complete
+- Primary-source guidance supplies hard constraints for OAuth/OIDC, account linking, WeChat identity handling, and payment completion semantics.
+
+The final rebuild must therefore:
+
+- keep the `feat/auth-identity-foundation` data model direction
+- absorb the strongest business-flow behavior from `personal-dev-branch`
+- reject transitional or half-finished behavior from both branches
+- treat compatibility and rollout as first-class implementation scope
+
+## Keep / Adapt / Drop
+
+### Keep
+
+Keep these architectural choices essentially intact:
+
+- `auth_identities`, `auth_identity_channels`, `pending_auth_sessions`, `identity_adoption_decisions`
+- per-provider default grants with one-time grant tracking
+- WeChat canonical identity plus channel mapping model
+- pending-auth verification gates before final bind
+- payment visible-method abstraction (`alipay`, `wechat`) decoupled from backend provider source
+- OpenAI advanced scheduler layering and test-backed behavior
+
+Keep these operational flow ideas from `personal-dev-branch`:
+
+- LinuxDo pending identity callback flow
+- WeChat pending identity callback flow
+- profile bindings UX and “cannot disconnect last usable login method” rule
+- separate WeChat login OAuth and WeChat payment OAuth entry points
+- historical synthetic-email recognition logic as a migration bridge
+
+### Adapt
+
+These areas must be reimplemented with the same intent but stricter boundaries:
+
+- third-party account creation from pending-auth state must be transactional and must not register a plain local user before identity finalization succeeds
+- email identity lifecycle must become real dual-write state, not just one migration-time backfill
+- `signup_source` must be backfilled more accurately for known historical third-party users
+- WeChat payment recovery state must move from frontend-only storage to server-backed continuation state
+- avatar adoption fetches must be security-hardened and failure-visible
+- pending-auth payload modeling must clearly separate immutable upstream payload from mutable local metadata
+- profile binding/avatar DTOs must be simplified to one authoritative backend contract instead of sprawling frontend fallback parsing
+- admin settings should preserve capability while reducing duplicated or transitional config branches
+
+### Drop
+
+Drop these as long-term design choices:
+
+- `user_external_identities` as the primary long-term identity model
+- synthetic email as a long-term canonical identity representation
+- OIDC as a side-path that does not participate in the same identity foundation as LinuxDo and WeChat
+- frontend multi-endpoint probing and broad compatibility parsing once the clean branch becomes the sole supported contract
+- unrelated branch noise such as generated-file churn, locale-only churn, or upstream merge residue as design inputs
+
+## Audit-Driven Hard Constraints
+
+The audit and source review establish these hard constraints:
+
+### Auth
+
+- all authorization-code providers use PKCE where applicable
+- callback handling uses strict `redirect_uri` discipline and state validation
+- OIDC identity key is `issuer + sub`
+- existing-account linking after email conflict must require explicit user action plus local-account verification
+- WeChat canonical identity key is `unionid`; `openid` is channel-scoped only
+
+### Compatibility
+
+- existing email users must continue to work with no manual intervention
+- existing LinuxDo users must not split into duplicate accounts
+- historical LinuxDo synthetic-email users must be backfilled into canonical LinuxDo identities during migration, not only lazily on next login
+- migration backfills must not trigger signup or first-bind grants
+- legacy `pending_auth_token` and `pending_oauth_token` contracts must remain accepted during rollout
+- legacy auth/public setting aliases needed by older frontend builds must remain available during rollout
+- existing payment configs and historical order semantics must remain valid
+
+### Payment
+
+- frontend return pages do not determine final payment success
+- backend order state, webhook processing, and/or provider status query remain authoritative
+- each visible method (`alipay`, `wechat`) may have only one active backend source at a time
+
+## Known Risks To Eliminate In Implementation
+
+These are specifically observed problems in the existing branches that the clean rebuild must eliminate:
+
+- third-party forced-email account creation currently bypasses the provider-aware account creation path and can leave orphan local accounts if bind finalization fails
+- post-migration email accounts are not fully dual-written into `auth_identities`
+- avatar adoption currently risks silent failure and insecure outbound fetch behavior
+- pending-auth payload responsibilities are internally inconsistent
+- OIDC parity is incomplete in `personal-dev-branch`; it must become a first-class provider in the unified identity model
+- WeChat union/open/channel identity handling is conceptually correct in the feature branch but still partially transitional across the codebase
+- WeChat payment recovery in `personal-dev-branch` is frontend-local and not robust across tabs or concurrent attempts
+- the existing pending-auth migration update is too destructive to reuse unchanged in a safer rollout
+- historical provider provenance should not be permanently flattened to `signup_source = email`
+
+## Rollout Gates
+
+The rebuild is not ready for rollout until all of these are satisfied:
+
+1. Identity schema and migration chain are linearized and production-safe.
+2. Email identity backfill is complete and idempotent.
+3. Historical LinuxDo synthetic-email backfill to canonical LinuxDo identity is complete and idempotent.
+4. `signup_source` backfill is accurate for known historical provider-created users.
+5. Dual token acceptance and required legacy field aliases are present.
+6. Existing payment configs are normalized and verified against current frontend-visible capabilities.
+7. New frontend flows are verified against mixed-version backend compatibility windows.
+8. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
--
GitLab
From b6751f7ebce43ef115ca583374b2e9c8d3d2dc28 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 13:47:00 +0800
Subject: [PATCH 029/261] docs: add auth identity implementation plan
---
...-04-20-auth-identity-payment-foundation.md | 476 ++++++++++++++++++
1 file changed, 476 insertions(+)
create mode 100644 docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
diff --git a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
new file mode 100644
index 00000000..e8fde9c0
--- /dev/null
+++ b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
@@ -0,0 +1,476 @@
+# Auth Identity Payment Foundation Implementation Plan
+
+> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
+
+**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users and historical LinuxDo users.
+**Architecture:** A unified identity foundation centered on durable provider subjects (`email`, `linuxdo`, `oidc`, `wechat`) and transactional pending-auth sessions; backend-owned payment source routing behind stable frontend methods (`alipay`, `wxpay`); compatibility-first migration/backfill before feature enablement.
+**Tech Stack:** Go, Gin, Ent, PostgreSQL, Redis, Vue 3, Pinia, TypeScript, Vitest, pnpm.
+
+---
+
+## Non-Negotiable Product Rules
+
+- [ ] Preserve login continuity for existing email users and historical LinuxDo users.
+- [ ] During migration, backfill historical LinuxDo synthetic-email users into explicit LinuxDo identities before first post-upgrade login.
+- [ ] Keep existing email login and add third-party login/bind for `linuxdo`, `oidc`, and `wechat`.
+- [ ] On first third-party login:
+ - identity exists: direct login.
+ - identity does not exist: start pending-auth flow.
+ - local email binding is required only when system config says so.
+ - upstream provider email verification never counts as local email verification.
+- [ ] When user-entered and locally verified email already exists:
+ - offer bind-existing-account after local re-authentication.
+ - offer change-email-and-create-new-account.
+ - when email binding is mandatory, do not allow bypass without changing to another email.
+- [ ] On first third-party login or first third-party bind, the provider nickname and avatar must be presented as two independent, opt-in replacement options for the current nickname and avatar. Neither is ever auto-applied.
+- [ ] Source-specific initial grants must support per-source defaults for balance, concurrency, and subscriptions.
+- [ ] Default grant timing: on successful new-account creation.
+- [ ] Optional grant timing: on first successful bind for the configured source.
+- [ ] Migration/backfill must never trigger first-bind or first-signup grants retroactively.
+- [ ] Avatar profile supports:
+ - direct URL storage.
+ - image data URL upload compressed to `<=100KB` before storing in DB.
+ - explicit delete.
+- [ ] Admin user management must expose and sort by `last_login_at` and `last_active_at`.
+- [ ] WeChat login rules:
+ - WeChat environment uses MP login.
+ - non-WeChat browser uses Open/QR login.
+ - canonical identity uses `unionid`.
+ - when `unionid` is unavailable, fail the login/bind flow under the approved option-1 policy.
+- [ ] Payment UI rules:
+ - user-facing methods stay `支付宝` and `微信支付`.
+ - backend decides whether each method routes to official provider instance or EasyPay.
+ - at runtime, each visible method may only have one active source.
+- [ ] Alipay rules:
+ - PC: in-page QR.
+ - mobile browser: jump to Alipay payment.
+- [ ] WeChat Pay rules:
+ - PC: in-page QR.
+  - WeChat H5: MP/JSAPI first; fall back to H5 pay.
+  - non-WeChat H5: use H5 pay; when H5 pay is unavailable, prompt the user to open the page in WeChat.
+- [ ] Payment success pages are informational only; actual fulfillment depends on webhook or server-side reconciliation.
+- [ ] OpenAI advanced scheduler is available but default-disabled.
+
+## Hard Technical Constraints From Audit
+
+- [ ] Browser-based third-party auth must use Authorization Code + PKCE `S256`.
+- [ ] OIDC identity primary key must be `(issuer, subject)`, not email.
+- [ ] Email equality must never auto-link accounts.
+- [ ] Bind-existing-account must require explicit local re-authentication and TOTP verification when enabled.
+- [ ] OAuth redirect URI must be fixed server config, exact-match, and never derived from user input.
+- [ ] User-supplied redirect may only choose a normalized same-origin internal route after completion.
+- [ ] WeChat canonical identity must be `unionid`; `openid` remains channel/app-scoped support data only.
+- [ ] Every payment order must snapshot the selected provider instance and reuse that exact instance for callback verification, reconciliation, refund, and audit.
+- [ ] Frontend must not receive first-party bearer tokens through callback URL fragments in the rebuilt flow.
+- [ ] Public payment result polling must not expose order data by raw `out_trade_no` alone; use authenticated lookup or signed opaque result token.
+
+## Baseline Notes
+
+- [ ] Current clean branch head when this plan was written: `721d7ab3`.
+- [ ] Baseline backend verification on clean `origin/main`: `cd backend && go test ./...` passes.
+- [ ] Baseline frontend verification on clean `origin/main`: `cd frontend && pnpm test:run` currently fails in unrelated existing suites. New work must add targeted tests and avoid claiming full frontend green until those baseline failures are addressed separately.
+- [ ] Existing migration directory currently ends at `107_*`; this rebuild reserves `108` through `111`.
+
+## Target File Map
+
+### New backend migrations
+
+- [ ] `backend/migrations/108_auth_identity_foundation_core.sql`
+- [ ] `backend/migrations/109_auth_identity_compat_backfill.sql`
+- [ ] `backend/migrations/110_pending_auth_and_provider_default_grants.sql`
+- [ ] `backend/migrations/111_payment_routing_and_scheduler_flags.sql`
+
+### New or rebuilt Ent schema
+
+- [ ] `backend/ent/schema/auth_identity.go`
+- [ ] `backend/ent/schema/auth_identity_channel.go`
+- [ ] `backend/ent/schema/pending_auth_session.go`
+- [ ] `backend/ent/schema/identity_adoption_decision.go`
+
+### New or rebuilt backend repositories/services/handlers
+
+- [ ] `backend/internal/repository/user_profile_identity_repo.go`
+- [ ] `backend/internal/repository/user_profile_identity_repo_contract_test.go`
+- [ ] `backend/internal/repository/auth_identity_migration_report.go`
+- [ ] `backend/internal/service/auth_identity_flow.go`
+- [ ] `backend/internal/service/auth_identity_flow_test.go`
+- [ ] `backend/internal/service/auth_pending_identity_service.go`
+- [ ] `backend/internal/service/auth_pending_identity_service_test.go`
+- [ ] `backend/internal/service/payment_config_service.go`
+- [ ] `backend/internal/service/payment_order.go`
+- [ ] `backend/internal/service/payment_order_lifecycle.go`
+- [ ] `backend/internal/service/payment_fulfillment.go`
+- [ ] `backend/internal/service/openai_account_scheduler.go`
+- [ ] `backend/internal/handler/auth_pending_identity_flow.go`
+- [ ] `backend/internal/handler/auth_linuxdo_oauth.go`
+- [ ] `backend/internal/handler/auth_oidc_oauth.go`
+- [ ] `backend/internal/handler/auth_wechat_oauth.go`
+- [ ] `backend/internal/handler/auth_handler.go`
+- [ ] `backend/internal/handler/user_handler.go`
+- [ ] `backend/internal/handler/payment_handler.go`
+- [ ] `backend/internal/handler/payment_webhook_handler.go`
+- [ ] `backend/internal/handler/admin/user_handler.go`
+- [ ] `backend/internal/handler/admin/setting_handler.go`
+
+### New or rebuilt frontend API/store/views/components
+
+- [ ] `frontend/src/api/auth.ts`
+- [ ] `frontend/src/api/user.ts`
+- [ ] `frontend/src/api/payment.ts`
+- [ ] `frontend/src/api/admin/settings.ts`
+- [ ] `frontend/src/api/admin/users.ts`
+- [ ] `frontend/src/stores/auth.ts`
+- [ ] `frontend/src/stores/payment.ts`
+- [ ] `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue`
+- [ ] `frontend/src/components/auth/LinuxDoOAuthSection.vue`
+- [ ] `frontend/src/components/auth/OidcOAuthSection.vue`
+- [ ] `frontend/src/components/auth/WechatOAuthSection.vue`
+- [ ] `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue`
+- [ ] `frontend/src/components/user/profile/ProfileInfoCard.vue`
+- [ ] `frontend/src/views/auth/LinuxDoCallbackView.vue`
+- [ ] `frontend/src/views/auth/OidcCallbackView.vue`
+- [ ] `frontend/src/views/auth/WechatCallbackView.vue`
+- [ ] `frontend/src/views/user/ProfileView.vue`
+- [ ] `frontend/src/views/user/PaymentView.vue`
+- [ ] `frontend/src/views/user/PaymentQRCodeView.vue`
+- [ ] `frontend/src/views/user/PaymentResultView.vue`
+
+## Phase 1: Migration And Compatibility Foundation
+
+### Task 1. Create core identity schema migration
+
+- [ ] Implement `backend/migrations/108_auth_identity_foundation_core.sql` with:
+ - `auth_identities`
+ - `auth_identity_channels`
+ - `pending_auth_sessions`
+ - `identity_adoption_decisions`
+ - `users.last_login_at`
+ - `users.last_active_at`
+ - grant-tracking columns/tables required to prevent double-award
+- [ ] Add uniqueness/index rules:
+ - one canonical identity per `(provider, provider_subject)`
+ - one channel record per `(provider, provider_channel, provider_app_id, provider_channel_subject)`
+ - one adoption decision per pending session
+- [ ] Preserve null-safe compatibility defaults so historical rows remain readable before backfill finishes.
+- [ ] Add explicit rollback blocks only where safe; never repeat the destructive pattern observed in old `112_update_pending_auth_sessions.sql`.
+
+### Task 2. Materialize historical identities before runtime
+
+- [ ] Implement `backend/migrations/109_auth_identity_compat_backfill.sql` to backfill:
+ - existing email users into `auth_identities(provider=email, provider_subject=normalized_email)`
+ - historical LinuxDo users into `auth_identities(provider=linuxdo, provider_subject=linuxdo_subject)`
+ - historical synthetic-email LinuxDo users into explicit LinuxDo identity rows by parsing legacy email mode and legacy provider metadata
+ - profile/channel rows from historical `user_external_identities`-style data when present in upgraded databases
+- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows instead of silently skipping them.
+- [ ] Set `signup_source` and provider provenance when recoverable from historical data. Do not flatten everything to `email`.
+
+### Task 3. Provider default grant and scheduler config migration
+
+- [ ] Implement `backend/migrations/110_pending_auth_and_provider_default_grants.sql` for:
+ - provider-specific initial balance/concurrency/subscription defaults
+ - grant timing flags: `on_signup`, optional `on_first_bind`
+ - email-required-on-third-party-signup flags
+ - profile avatar storage columns/settings
+- [ ] Implement `backend/migrations/111_payment_routing_and_scheduler_flags.sql` for:
+ - stable payment method to provider-instance routing
+ - admin exclusivity flags for `alipay` and `wxpay`
+ - advanced scheduler enable flag defaulting to disabled
+
+### Task 4. Generate Ent and compile migration-safe model layer
+
+- [ ] Add the schema definitions in:
+ - `backend/ent/schema/auth_identity.go`
+ - `backend/ent/schema/auth_identity_channel.go`
+ - `backend/ent/schema/pending_auth_session.go`
+ - `backend/ent/schema/identity_adoption_decision.go`
+- [ ] Run:
+ ```bash
+ cd backend
+ go generate ./ent
+ ```
+- [ ] Compile after generation:
+ ```bash
+ cd backend
+ go test ./... -run '^$'
+ ```
+- [ ] Commit checkpoint:
+ ```bash
+ git add backend/migrations backend/ent/schema backend/ent
+ git commit -m "feat: add auth identity foundation schema"
+ ```
+
+## Phase 2: Backend Identity Flow Rebuild
+
+### Task 5. Build a single repository contract for identity lookups and grants
+
+- [ ] Implement `backend/internal/repository/user_profile_identity_repo.go` with transactional helpers for:
+ - get user by canonical identity
+ - get user by channel identity
+ - create canonical + channel identity together
+ - bind identity to existing user after verified re-auth
+ - record one-time provider grant award
+ - record adoption preference decisions
+ - update `last_login_at` and `last_active_at`
+- [ ] Add repository contract coverage in `backend/internal/repository/user_profile_identity_repo_contract_test.go`.
+- [ ] Enforce dual-write for email registration/login so `users.email` and `auth_identities(provider=email, ...)` stay consistent from this phase onward.
+
+### Task 6. Rebuild transactional pending-auth service
+
+- [ ] Implement `backend/internal/service/auth_pending_identity_service.go` and tests to own these flows:
+ - create pending session from third-party callback
+ - verify local email code
+ - create new account from pending session with correct `signup_source`
+ - bind pending identity to existing account after password/TOTP re-auth
+ - apply configured provider defaults on the correct trigger only once
+ - store provider nickname/avatar candidates and user opt-in replacement decisions independently
+- [ ] Keep pending session payload normalized:
+ - provider identity fields live in typed columns/JSON structure
+ - avoid the old branch’s mixed `metadata` and `upstream_identity_payload` ambiguity
+- [ ] Do not call plain email registration helpers from this flow. The old feature branch bug where pending third-party signup fell back to `RegisterWithVerification` must not reappear.
+
+### Task 7. Rebuild provider callback adapters
+
+- [ ] Refactor these handlers to thin adapters over the shared pending-auth service:
+ - `backend/internal/handler/auth_linuxdo_oauth.go`
+ - `backend/internal/handler/auth_oidc_oauth.go`
+ - `backend/internal/handler/auth_wechat_oauth.go`
+- [ ] For OIDC:
+ - require PKCE `S256`, `state`, and `nonce`
+ - validate `iss`, `aud`, optional `azp`, `exp`, `nonce`
+ - persist canonical identity as `(issuer, sub)`
+- [ ] For WeChat:
+ - MP flow in WeChat UA
+ - Open/QR flow outside WeChat UA
+ - persist channel identity by `(channel, appid, openid)`
+ - persist canonical identity by `unionid`
+ - hard-fail when `unionid` is absent under the approved product policy
+- [ ] Replace callback URL fragment token delivery with backend session completion or one-time exchange code consumed by `frontend/src/stores/auth.ts`.
+
+### Task 8. Rebuild auth endpoints and profile binding endpoints
+
+- [ ] Implement `backend/internal/handler/auth_pending_identity_flow.go` for:
+ - fetch pending session summary
+ - submit verified email
+ - choose create-new-account or bind-existing-account
+ - submit nickname/avatar replacement choices
+- [ ] Update `backend/internal/handler/auth_handler.go` and `backend/internal/handler/user_handler.go` to expose:
+ - current bindings summary
+ - start-bind endpoints for LinuxDo/OIDC/WeChat
+ - disconnect endpoints with safety checks
+ - avatar upload/delete endpoints
+- [ ] Avatar handling requirements:
+ - allow external URL
+ - allow data URL upload
+ - compress image payload to `<=100KB`
+ - store compressed value in DB
+ - deleting custom avatar must not implicitly resurrect stale provider avatar unless the user explicitly chooses provider avatar again
+
+### Task 9. Add admin visibility and sorting
+
+- [ ] Update `backend/internal/handler/admin/user_handler.go` and supporting query/service code so admin list supports:
+ - `last_login_at`
+ - `last_active_at`
+ - sorting by both
+ - binding/provider summary columns
+- [ ] Update `backend/internal/handler/admin/setting_handler.go` and setting service code for:
+ - provider initial grant config
+ - mandatory-email-on-third-party-signup config
+ - payment source exclusivity config
+ - advanced scheduler toggle
+
+### Task 10. Backend verification checkpoint
+
+- [ ] Run targeted backend tests:
+ ```bash
+ cd backend
+ go test ./internal/repository -run 'TestUserProfileIdentity|TestAuthIdentityMigration'
+ go test ./internal/service -run 'TestAuthIdentityFlow|TestPendingAuthIdentity|TestOpenAIAccountScheduler'
+ go test ./internal/handler -run 'TestLinuxDo|TestOidc|TestWechat|TestPaymentWebhook'
+ go test ./...
+ ```
+- [ ] Commit checkpoint:
+ ```bash
+ git add backend
+ git commit -m "feat: rebuild auth identity backend flows"
+ ```
+
+## Phase 3: Frontend Third-Party Flow And Profile UX
+
+### Task 11. Rebuild callback flow UI around pending session decisions
+
+- [ ] Rebuild `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue` so it:
+ - loads pending-session summary from backend
+ - shows provider nickname/avatar candidates
+ - lets user independently choose nickname replacement and avatar replacement
+ - handles create-new-account vs bind-existing-account
+ - enforces verified local email before completion when required
+ - handles “email already exists” by branching to bind-existing-account or change-email-and-create-new-account
+- [ ] Update:
+ - `frontend/src/views/auth/LinuxDoCallbackView.vue`
+ - `frontend/src/views/auth/OidcCallbackView.vue`
+ - `frontend/src/views/auth/WechatCallbackView.vue`
+ - `frontend/src/api/auth.ts`
+ - `frontend/src/stores/auth.ts`
+- [ ] Replace any token-fragment bootstrap with backend session completion or one-time exchange code flow.
+
+### Task 12. Rebuild profile account binding and avatar UX
+
+- [ ] Rebuild `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue` to:
+ - show linked LinuxDo/OIDC/WeChat providers
+ - start bind/unbind flows
+ - show provider avatars and nicknames as reference only
+ - prevent unsafe disconnect when it would strand the account
+- [ ] Rebuild `frontend/src/components/user/profile/ProfileInfoCard.vue` and `frontend/src/views/user/ProfileView.vue` to:
+ - support avatar URL entry
+ - support data URL upload/compression preview
+ - support avatar delete
+ - clearly separate current profile nickname/avatar from provider-sourced suggested nickname/avatar
+
+### Task 13. Add frontend tests for rebuilt auth/profile flows
+
+- [ ] Add or update:
+ - `frontend/src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts`
+ - `frontend/src/components/auth/__tests__/LinuxDoCallbackView.spec.ts`
+ - `frontend/src/components/auth/__tests__/WechatCallbackView.spec.ts`
+ - `frontend/src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts`
+ - `frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts`
+- [ ] Cover:
+ - email-required branch
+ - email-conflict branch
+ - bind-existing-account with re-auth prompt
+ - nickname replacement only
+ - avatar replacement only
+ - neither replacement
+ - avatar delete after prior provider adoption
+
+## Phase 4: Payment Routing Rebuild
+
+### Task 14. Normalize payment routing backend
+
+- [ ] Rebuild `backend/internal/service/payment_config_service.go` to expose a stable method-routing contract:
+ - frontend visible methods remain `alipay` and `wxpay`
+ - admin chooses which provider instance serves each method
+ - runtime validation guarantees only one active source per visible method
+- [ ] Rebuild `backend/internal/service/payment_order.go` and `backend/internal/service/payment_order_lifecycle.go` so order creation snapshots:
+ - visible method
+ - selected provider instance id
+ - provider type
+ - provider capability mode
+- [ ] Rebuild `backend/internal/handler/payment_handler.go` for UX rules:
+ - Alipay PC: QR page
+ - Alipay mobile: direct jump
+ - WeChat PC: QR page
+  - WeChat H5 in WeChat: MP/JSAPI first; fall back to H5 pay
+  - WeChat H5 outside WeChat: use H5 pay; when H5 pay is unavailable, show an “open in WeChat” prompt
+- [ ] Never derive canonical return URL from `Referer`; use configured or signed internal callback targets only.
+
+### Task 15. Make fulfillment and reconciliation provider-instance-safe
+
+- [ ] Rebuild `backend/internal/handler/payment_webhook_handler.go` and `backend/internal/service/payment_fulfillment.go` so:
+ - verification uses the order’s original provider instance
+ - webhook processing is idempotent by provider event id and internal order id
+ - missed webhook recovery uses server-side provider query, not frontend success return
+- [ ] Harden `frontend/src/views/user/PaymentResultView.vue` and `frontend/src/api/payment.ts` so result polling uses an authenticated order lookup or signed opaque token, not a raw public `out_trade_no` query.
+
+### Task 16. Rebuild payment frontend views
+
+- [ ] Rebuild `frontend/src/views/user/PaymentView.vue`, `frontend/src/views/user/PaymentQRCodeView.vue`, and `frontend/src/stores/payment.ts` so:
+ - only two buttons are shown to user: `支付宝` and `微信支付`
+ - frontend does not leak official-vs-EasyPay distinction
+ - route-specific copy handles QR, jump, MP, H5 fallback correctly
+- [ ] Add or update:
+ - `frontend/src/views/user/__tests__/PaymentView.spec.ts`
+ - `frontend/src/views/user/__tests__/PaymentResultView.spec.ts`
+ - backend webhook/payment routing tests
+
+### Task 17. Payment verification checkpoint
+
+- [ ] Run:
+ ```bash
+ cd backend
+ go test ./internal/service -run 'TestPayment'
+ go test ./internal/handler -run 'TestPayment'
+ cd ../frontend
+ pnpm test:run src/views/user/__tests__/PaymentView.spec.ts src/views/user/__tests__/PaymentResultView.spec.ts
+ ```
+- [ ] Commit checkpoint:
+ ```bash
+ git add backend frontend
+ git commit -m "feat: rebuild payment routing foundation"
+ ```
+
+## Phase 5: Scheduler, Rollout, And Final Compatibility Pass
+
+### Task 18. Gate advanced scheduler behind explicit config
+
+- [ ] Update `backend/internal/service/openai_account_scheduler.go` and related admin setting surfaces so:
+ - advanced scheduler remains compiled and testable
+ - default runtime state is disabled
+ - enablement is explicit through admin settings
+ - legacy scheduling behavior remains default on upgrade
+- [ ] Add targeted coverage in `backend/internal/service/openai_account_scheduler_test.go`.
+
+### Task 19. Complete compatibility and rollout safety checks
+
+- [ ] Add migration/repository tests covering:
+ - historical email-only user login after upgrade
+ - historical LinuxDo user login after upgrade
+ - historical synthetic-email LinuxDo user login after upgrade
+ - no retroactive grant replay during migration
+ - first-bind grant fires once only when enabled
+ - email identity dual-write stays consistent
+ - bind-existing-account requires password and TOTP where configured
+- [ ] Add deploy sequencing note to release docs or internal runbook:
+ 1. deploy schema and backfill release.
+ 2. inspect migration report for unmatched rows.
+ 3. deploy backend identity/payment compatibility code.
+ 4. deploy frontend callback/profile/payment UI.
+ 5. enable strict email-required signup or provider bind grants only after metrics are healthy.
+
+### Task 20. Final verification and handoff
+
+- [ ] Run final backend verification:
+ ```bash
+ cd backend
+ go test ./...
+ ```
+- [ ] Run targeted frontend verification:
+ ```bash
+ cd frontend
+ pnpm test:run \
+ src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts \
+ src/components/auth/__tests__/LinuxDoCallbackView.spec.ts \
+ src/components/auth/__tests__/WechatCallbackView.spec.ts \
+ src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts \
+ src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
+ src/views/user/__tests__/PaymentView.spec.ts \
+ src/views/user/__tests__/PaymentResultView.spec.ts
+ ```
+- [ ] Run focused manual smoke checks:
+ - email login with existing account
+ - LinuxDo existing-account login after migration
+ - third-party first login create-new-account path
+ - third-party first login bind-existing-account path
+ - first third-party bind with optional nickname/avatar replacement
+ - PC Alipay QR
+ - mobile Alipay jump
+ - PC WeChat QR
+ - WeChat H5 MP/JSAPI path
+ - non-WeChat H5 fallback path
+- [ ] Commit final checkpoint:
+ ```bash
+ git add docs backend frontend
+ git commit -m "feat: rebuild auth identity and payment foundation"
+ ```
+
+## Review Checklist
+
+- [ ] No flow still relies on provider email equality for account linking.
+- [ ] No flow still creates third-party users through plain email registration helpers.
+- [ ] No callback still returns first-party bearer tokens in URL fragments.
+- [ ] No payment result view trusts provider return page as authoritative fulfillment.
+- [ ] No webhook verification path selects provider credentials from “currently active config” instead of the order snapshot.
+- [ ] Existing email users and historical LinuxDo users are covered by migration tests.
+- [ ] Avatar adoption and deletion semantics are explicit and reversible.
+- [ ] Grant timing is source-aware and one-time only.
+
--
GitLab
From 584ded2182e2b1225a67e55d8a5485bbbdc35658 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 14:41:12 +0800
Subject: [PATCH 030/261] docs: harden auth identity payment design
---
...-04-20-auth-identity-payment-foundation.md | 87 ++++++++--
...auth-identity-payment-foundation-design.md | 149 +++++++++++++++---
2 files changed, 199 insertions(+), 37 deletions(-)
diff --git a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
index e8fde9c0..2d44e058 100644
--- a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
+++ b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
@@ -2,7 +2,7 @@
> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
-**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users and historical LinuxDo users.
+**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only records.
**Architecture:** A unified identity foundation centered on durable provider subjects (`email`, `linuxdo`, `oidc`, `wechat`) and transactional pending-auth sessions; backend-owned payment source routing behind stable frontend methods (`alipay`, `wxpay`); compatibility-first migration/backfill before feature enablement.
**Tech Stack:** Go, Gin, Ent, PostgreSQL, Redis, Vue 3, Pinia, TypeScript, Vitest, pnpm.
@@ -10,8 +10,9 @@
## Non-Negotiable Product Rules
-- [ ] Preserve login continuity for existing email users and historical LinuxDo users.
-- [ ] During migration, backfill historical LinuxDo synthetic-email users into explicit LinuxDo identities before first post-upgrade login.
+- [ ] Preserve login continuity for existing email users, existing LinuxDo users, and historically migrated third-party users.
+- [ ] During migration, backfill historical LinuxDo/WeChat/OIDC synthetic-email users into explicit third-party identities before first post-upgrade login whenever deterministic recovery is possible.
+- [ ] During migration, surface historical WeChat `openid`-only records through explicit migration reports and remediation rules; do not silently reinterpret them as valid canonical identities.
- [ ] Keep existing email login and add third-party login/bind for `linuxdo`, `oidc`, and `wechat`.
- [ ] On first third-party login:
- identity exists: direct login.
@@ -37,6 +38,11 @@
- non-WeChat browser uses Open/QR login.
- canonical identity uses `unionid`.
- when `unionid` is unavailable, fail the login/bind flow under the approved option-1 policy.
+- [ ] OIDC rules:
+ - browser authorization-code flow always uses PKCE `S256`.
+ - discovery issuer and ID token `iss` must match exactly.
+ - `userinfo.sub` must match ID token `sub` when UserInfo is used.
+ - upstream `email_verified` does not satisfy local email verification.
- [ ] Payment UI rules:
- user-facing methods stay `支付宝` and `微信支付`.
- backend decides whether each method routes to official provider instance or EasyPay.
@@ -49,20 +55,26 @@
- WeChat H5: MP/JSAPI first, fallback to H5 pay.
- non-WeChat H5: H5 pay, or prompt to open in WeChat when unavailable.
- [ ] Payment success pages are informational only; actual fulfillment depends on webhook or server-side reconciliation.
+- [ ] WeChat in-app payment requiring `openid` must use a dedicated server-backed payment OAuth resume flow rather than frontend-only recovery state.
- [ ] OpenAI advanced scheduler is available but default-disabled.
## Hard Technical Constraints From Audit
- [ ] Browser-based third-party auth must use Authorization Code + PKCE `S256`.
+- [ ] PKCE must not be admin-configurable off for browser authorization-code providers.
- [ ] OIDC identity primary key must be `(issuer, subject)`, not email.
- [ ] Email equality must never auto-link accounts.
- [ ] Bind-existing-account must require explicit local re-authentication and TOTP verification when enabled.
+- [ ] Bind-current-user must originate from an already-authenticated local user and preserve explicit bind intent across callback completion.
- [ ] OAuth redirect URI must be fixed server config, exact-match, and never derived from user input.
- [ ] User-supplied redirect may only choose a normalized same-origin internal route after completion.
- [ ] WeChat canonical identity must be `unionid`; `openid` remains channel/app-scoped support data only.
-- [ ] Every payment order must snapshot the selected provider instance and reuse that exact instance for callback verification, reconciliation, refund, and audit.
+- [ ] Every canonical identity uniqueness rule must include provider namespace (`provider_key`) consistently.
+- [ ] Callback completion must use backend session completion or a one-time opaque exchange code that is short-lived, one-time, browser-session-bound, `POST`-redeemed, and unusable as a bearer token.
+- [ ] Every payment order must snapshot the selected provider instance plus the order-time verification inputs required for callback verification, reconciliation, refund, and audit.
- [ ] Frontend must not receive first-party bearer tokens through callback URL fragments in the rebuilt flow.
- [ ] Public payment result polling must not expose order data by raw `out_trade_no` alone; use authenticated lookup or signed opaque result token.
+- [ ] WeChat Pay webhook handling must verify signature, decrypt payload, and compare `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and provider trade state against the order snapshot before fulfillment.
## Baseline Notes
@@ -100,6 +112,8 @@
- [ ] `backend/internal/service/payment_order.go`
- [ ] `backend/internal/service/payment_order_lifecycle.go`
- [ ] `backend/internal/service/payment_fulfillment.go`
+- [ ] `backend/internal/service/payment_resume_service.go`
+- [ ] `backend/internal/service/payment_resume_service_test.go`
- [ ] `backend/internal/service/openai_account_scheduler.go`
- [ ] `backend/internal/handler/auth_pending_identity_flow.go`
- [ ] `backend/internal/handler/auth_linuxdo_oauth.go`
@@ -148,9 +162,10 @@
- `users.last_active_at`
- grant-tracking columns/tables required to prevent double-award
- [ ] Add uniqueness/index rules:
- - one canonical identity per `(provider, provider_subject)`
+ - one canonical identity per `(provider, provider_key, provider_subject)`
- one channel record per `(provider, provider_channel, provider_app_id, provider_channel_subject)`
- one adoption decision per pending session
+- [ ] Model `pending_auth_sessions` so immutable upstream claims and mutable local flow state are stored separately; do not reintroduce a mixed `metadata` catch-all.
- [ ] Preserve null-safe compatibility defaults so historical rows remain readable before backfill finishes.
- [ ] Add explicit rollback blocks only where safe; never repeat the destructive pattern observed in old `112_update_pending_auth_sessions.sql`.
@@ -160,8 +175,10 @@
- existing email users into `auth_identities(provider=email, provider_subject=normalized_email)`
- historical LinuxDo users into `auth_identities(provider=linuxdo, provider_subject=linuxdo_subject)`
- historical synthetic-email LinuxDo users into explicit LinuxDo identity rows by parsing legacy email mode and legacy provider metadata
+ - historical synthetic-email WeChat users into explicit WeChat identities where `unionid` or equivalent deterministic provider identity is recoverable
+ - historical synthetic-email OIDC users into explicit OIDC identities where deterministic provider identity is recoverable
- profile/channel rows from historical `user_external_identities`-style data when present in upgraded databases
-- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows instead of silently skipping them.
+- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows, `openid`-only WeChat rows, and non-deterministic synthetic-email rows instead of silently skipping them.
- [ ] Set `signup_source` and provider provenance when recoverable from historical data. Do not flatten everything to `email`.
### Task 3. Provider default grant and scheduler config migration
@@ -173,6 +190,7 @@
- profile avatar storage columns/settings
- [ ] Implement `backend/migrations/111_payment_routing_and_scheduler_flags.sql` for:
- stable payment method to provider-instance routing
+ - visible-method normalization from historical `supported_types`, `payment_mode`, and legacy aliases such as `wxpay_direct`
- admin exclusivity flags for `alipay` and `wxpay`
- advanced scheduler enable flag defaulting to disabled
@@ -213,6 +231,7 @@
- update `last_login_at` and `last_active_at`
- [ ] Add repository contract coverage in `backend/internal/repository/user_profile_identity_repo_contract_test.go`.
- [ ] Enforce dual-write for email registration/login so `users.email` and `auth_identities(provider=email, ...)` stay consistent from this phase onward.
+- [ ] Add repository coverage proving `last_login_at` and `last_active_at` use the required field names and are not silently replaced by derived `last_used_at` logic.
### Task 6. Rebuild transactional pending-auth service
@@ -223,8 +242,15 @@
- bind pending identity to existing account after password/TOTP re-auth
- apply configured provider defaults on the correct trigger only once
- store provider nickname/avatar candidates and user opt-in replacement decisions independently
+- [ ] Implement callback completion so pending auth can finish through backend session completion or a one-time exchange code:
+ - short TTL
+ - one-time use
+ - browser-session binding
+ - `POST` redemption only
+ - safe mixed-version bridge to legacy pending-token aliases during rollout
- [ ] Keep pending session payload normalized:
- provider identity fields live in typed columns/JSON structure
+ - mutable local progression lives separately from immutable upstream claims
- avoid the old branch’s mixed `metadata` and `upstream_identity_payload` ambiguity
- [ ] Do not call plain email registration helpers from this flow. The old feature branch bug where pending third-party signup fell back to `RegisterWithVerification` must not reappear.
@@ -236,11 +262,13 @@
- `backend/internal/handler/auth_wechat_oauth.go`
- [ ] For OIDC:
- require PKCE `S256`, `state`, and `nonce`
- - validate `iss`, `aud`, optional `azp`, `exp`, `nonce`
+ - validate discovery issuer, `iss`, `aud`, optional `azp`, `exp`, and `nonce`
+ - verify `userinfo.sub == id_token.sub` when UserInfo is used
- persist canonical identity as `(issuer, sub)`
- [ ] For WeChat:
- MP flow in WeChat UA
- Open/QR flow outside WeChat UA
+ - website login uses authorization-code flow and persists channel/app binding
- persist channel identity by `(channel, appid, openid)`
- persist canonical identity by `unionid`
- hard-fail when `unionid` is absent under the approved product policy
@@ -253,6 +281,10 @@
- submit verified email
- choose create-new-account or bind-existing-account
- submit nickname/avatar replacement choices
+- [ ] Make bind-existing-account and bind-current-user flows explicit:
+ - no automatic linking on matching email
+ - fresh password/TOTP proof is scoped to the intended target account only
+ - no automatic metadata merge beyond explicitly selected nickname/avatar adoption
- [ ] Update `backend/internal/handler/auth_handler.go` and `backend/internal/handler/user_handler.go` to expose:
- current bindings summary
- start-bind endpoints for LinuxDo/OIDC/WeChat
@@ -312,6 +344,7 @@
- `frontend/src/api/auth.ts`
- `frontend/src/stores/auth.ts`
- [ ] Replace any token-fragment bootstrap with backend session completion or one-time exchange code flow.
+- [ ] During rollout, keep temporary compatibility readers for legacy pending-token aliases behind a bounded bridge contract and explicit removal step.
### Task 12. Rebuild profile account binding and avatar UX
@@ -351,11 +384,17 @@
- frontend visible methods remain `alipay` and `wxpay`
- admin chooses which provider instance serves each method
- runtime validation guarantees only one active source per visible method
+- [ ] Add migration logic and tests to normalize historical provider-instance config:
+ - `supported_types`
+ - `payment_mode`
+ - legacy aliases such as `wxpay_direct`
+ - historical limit config
- [ ] Rebuild `backend/internal/service/payment_order.go` and `backend/internal/service/payment_order_lifecycle.go` so order creation snapshots:
- visible method
- selected provider instance id
- provider type
- provider capability mode
+ - verification-critical provider fields needed for later callback/query/refund validation
- [ ] Rebuild `backend/internal/handler/payment_handler.go` for UX rules:
- Alipay PC: QR page
- Alipay mobile: direct jump
@@ -363,6 +402,11 @@
- WeChat H5 in WeChat: MP/JSAPI first, fallback to H5
- WeChat H5 outside WeChat: H5 or “open in WeChat” prompt when unavailable
- [ ] Never derive canonical return URL from `Referer`; use configured or signed internal callback targets only.
+- [ ] Implement `backend/internal/service/payment_resume_service.go` so WeChat in-app payment OAuth resume is server-backed rather than localStorage-backed:
+ - create `oauth_required` resume context
+ - persist amount/order_type/plan_id/visible method/redirect/state
+ - redeem callback into same-origin internal resume target
+ - expire and consume resume context safely
### Task 15. Make fulfillment and reconciliation provider-instance-safe
@@ -370,6 +414,12 @@
- verification uses the order’s original provider instance
- webhook processing is idempotent by provider event id and internal order id
- missed webhook recovery uses server-side provider query, not frontend success return
+- [ ] For WeChat Pay specifically, enforce:
+ - fixed HTTPS `notify_url` with no query params
+ - no dependency on user login state
+ - signature verification before decrypt
+ - APIv3 decrypt before business parsing
+ - comparison of `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and trade state against the order snapshot
- [ ] Harden `frontend/src/views/user/PaymentResultView.vue` and `frontend/src/api/payment.ts` so result polling uses an authenticated order lookup or signed opaque token, not a raw public `out_trade_no` query.
### Task 16. Rebuild payment frontend views
@@ -378,6 +428,10 @@
- only two buttons are shown to user: `支付宝` and `微信支付`
- frontend does not leak official-vs-EasyPay distinction
- route-specific copy handles QR, jump, MP, H5 fallback correctly
+- [ ] Rebuild WeChat in-app payment resume UX around the server-backed resume context:
+ - handle `oauth_required`
+ - continue from same-origin resume target
+ - avoid long-lived localStorage as the source of truth
- [ ] Add or update:
- `frontend/src/views/user/__tests__/PaymentView.spec.ts`
- `frontend/src/views/user/__tests__/PaymentResultView.spec.ts`
@@ -416,16 +470,22 @@
- historical email-only user login after upgrade
- historical LinuxDo user login after upgrade
- historical synthetic-email LinuxDo user login after upgrade
+ - historical synthetic-email WeChat user login after upgrade
+ - historical synthetic-email OIDC user login after upgrade
+ - historical WeChat `openid`-only rows are reported or explicitly remediated
- no retroactive grant replay during migration
- first-bind grant fires once only when enabled
- email identity dual-write stays consistent
- bind-existing-account requires password and TOTP where configured
+ - mixed-version callback token bridge works during rollout and is removable afterward
+ - historical payment config is normalized into visible-method routing without refund/query regression
- [ ] Add deploy sequencing note to release docs or internal runbook:
1. deploy schema and backfill release.
2. inspect migration report for unmatched rows.
- 3. deploy backend identity/payment compatibility code.
- 4. deploy frontend callback/profile/payment UI.
- 5. enable strict email-required signup or provider bind grants only after metrics are healthy.
+ 3. deploy backend identity/payment compatibility code with exchange bridge and legacy token aliases still enabled.
+ 4. deploy frontend callback/profile/payment UI using session completion, exchange code, and server-backed WeChat payment resume.
+ 5. remove legacy callback/token parsing after mixed-version window closes.
+ 6. enable strict email-required signup or provider bind grants only after metrics are healthy.
### Task 20. Final verification and handoff
@@ -449,6 +509,8 @@
- [ ] Run focused manual smoke checks:
- email login with existing account
- LinuxDo existing-account login after migration
+ - WeChat synthetic-email account login after migration
+ - OIDC synthetic-email account login after migration
- third-party first login create-new-account path
- third-party first login bind-existing-account path
- first third-party bind with optional nickname/avatar replacement
@@ -456,6 +518,7 @@
- mobile Alipay jump
- PC WeChat QR
- WeChat H5 MP/JSAPI path
+ - WeChat in-app OAuth resume path
- non-WeChat H5 fallback path
- [ ] Commit final checkpoint:
```bash
@@ -468,9 +531,9 @@
- [ ] No flow still relies on provider email equality for account linking.
- [ ] No flow still creates third-party users through plain email registration helpers.
- [ ] No callback still returns first-party bearer tokens in URL fragments.
+- [ ] No callback completion path can be replayed as a bearer token substitute.
- [ ] No payment result view trusts provider return page as authoritative fulfillment.
- [ ] No webhook verification path selects provider credentials from “currently active config” instead of the order snapshot.
-- [ ] Existing email users and historical LinuxDo users are covered by migration tests.
+- [ ] Existing email users, historical LinuxDo/WeChat/OIDC users, and `openid`-only WeChat remediation cases are covered by migration tests.
- [ ] Avatar adoption and deletion semantics are explicit and reversible.
- [ ] Grant timing is source-aware and one-time only.
-
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
index bfa250d9..790861b7 100644
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
@@ -20,10 +20,10 @@ This design includes:
- Profile binding management and avatar upload/delete
- Source-based initial grants for balance, concurrency, and subscriptions
- User management support for `last_login_at` and `last_active_at` sorting
-- Unified payment display methods (`alipay`, `wechat`) mapped to a single active backend source each
+- Unified payment display methods (`alipay`, `wxpay`) mapped to a single active backend source each
- Alipay and WeChat UX routing rules across PC, mobile, H5, and WeChat environments
- Admin settings for auth providers, source defaults, payment sources, and OpenAI advanced scheduling
-- Incremental migration and compatibility for existing email users and historical LinuxDo synthetic-email users
+- Incremental migration and compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only identity records
This design does not treat unrelated upstream merges, docs churn, or license changes from the old branch as required scope.
@@ -32,9 +32,11 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
### Auth and identity
- Existing email users remain valid and continue to log in with no manual action.
+- Existing LinuxDo, OIDC, and WeChat users represented by historical third-party or synthetic-email data must remain recoverable during migration.
- Third-party first login behavior:
- Existing bound identity: direct login
- Missing identity: start first-login flow
+- Browser-based third-party authorization-code login always uses PKCE `S256`; this is not an admin-toggleable feature.
- If `force_email_on_third_party_signup` is disabled, a first-login user may create an account without binding an email.
- If `force_email_on_third_party_signup` is enabled, the user must provide an email.
- If the provided and verified email already exists:
@@ -43,11 +45,25 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
- allow "change email and continue registration"
- do not allow bypassing the email requirement
- Upstream provider email verification is not trusted as a local bound email.
+- Matching upstream email must never auto-link to an existing local account.
+- Linking to an existing local account is allowed only when:
+ - the user explicitly chooses that target account
+ - the target account passes fresh local re-authentication
+ - required TOTP verification succeeds
+- A new third-party bind initiated from the profile page must start from an already logged-in local account and preserve the explicit bind intent end-to-end.
+- `redirect_to` may only represent a normalized same-origin internal route. It must never contain a third-party URL and must never be derived from `Referer`.
+- OIDC validation rules:
+ - canonical identity key is `issuer + sub`
+ - discovery issuer and ID token `iss` must match exactly
+ - `userinfo.sub` must match ID token `sub` when UserInfo is used
+ - upstream `email_verified` may improve UX copy but does not satisfy local email-binding requirements
- WeChat login chooses channel by environment:
- in WeChat environment: `mp`
- outside WeChat: `open`
- WeChat primary identity key is `unionid`.
- If a WeChat login/bind flow cannot produce `unionid`, the flow fails and no fallback `openid` identity is created.
+- Historical WeChat records that only contain `openid` are treated as migration-remediation cases, not as a valid long-term canonical identity model.
+- WeChat website login uses authorization code flow, random `state`, and the provider channel/app binding must be persisted alongside the resolved identity.
### Profile adoption
@@ -85,7 +101,7 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
- Frontend shows only two display methods:
- `alipay`
- - `wechat`
+ - `wxpay`
- Users never choose between official providers and EasyPay explicitly.
- Backend allows only one active source per display method at a time.
- Alipay UX:
@@ -96,6 +112,10 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
- non-WeChat H5: prefer H5 pay; if unavailable, tell the user to open in WeChat
- WeChat environment: prefer MP/JSAPI pay; if unavailable, fall back to H5 pay
- Payment success is confirmed by backend order state, webhook, and/or query, not only frontend return.
+- Frontend-visible labels remain `支付宝` and `微信支付`, while internal visible-method identifiers remain `alipay` and `wxpay`.
+- Public result pages must not verify order state by exposing raw `out_trade_no`; they use authenticated lookup or a signed opaque result token instead.
+- Payment callback or return URLs must be fixed same-origin internal targets. They must not be inferred from `Referer`.
+- WeChat payment webhook handling must use a fixed HTTPS `notify_url` with no query parameters and must not depend on user login state.
### OpenAI advanced scheduling
@@ -105,9 +125,9 @@ This design does not treat unrelated upstream merges, docs churn, or license cha
## Architecture
-Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
+Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, callback completion state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
-Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix.
+Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix, and stores enough order-time snapshot data to survive later provider-config changes.
Compatibility is a first-class concern: migrations are additive, reads are compatibility-aware, and rollout must tolerate existing `main` data and short-lived frontend/backend version skew.
@@ -148,9 +168,9 @@ Uniqueness:
Rules:
- email identity uses canonicalized local email
-- LinuxDo uses stable provider subject
-- OIDC uses stable issuer + subject
-- WeChat uses `unionid` as canonical subject
+- LinuxDo uses stable provider subject under the configured provider namespace
+- OIDC uses stable issuer + subject, with issuer namespace represented consistently through `provider_key` and `issuer`
+- WeChat uses `unionid` as canonical subject under the configured Open Platform namespace
### `auth_identity_channels`
@@ -189,9 +209,12 @@ Fields:
- `target_user_id`
- `redirect_to`
- `resolved_email`
-- `pending_password_hash`
-- `upstream_identity_payload`
-- `metadata`
+- `registration_password_hash`
+- `upstream_identity_claims`
+- `local_flow_state`
+- `browser_session_key`
+- `completion_code_hash`
+- `completion_code_expires_at`
- `email_verified_at`
- `password_verified_at`
- `totp_verified_at`
@@ -205,19 +228,33 @@ Responsibilities:
- persist nickname/avatar suggestions
- persist explicit adoption decisions
- survive navigation between auth pages
+- support mixed-version rollout through short-lived legacy token aliases when required
+
+Security rules:
+
+- callback completion uses backend session completion or a one-time exchange code
+- exchange codes are short-lived, one-time, bound to browser session and pending session, and redeemed via `POST`
+- exchange codes must not behave as bearer tokens and must not be logged, stored in URL fragments, or reused after redemption
+- `local_flow_state` stores mutable local progression only; immutable upstream claims remain in `upstream_identity_claims`
### `identity_adoption_decisions`
-Persists user adoption preference for a specific identity.
+Persists the user's adoption preference collected during a pending-auth flow; the preference is later resolved onto the bound identity.
Fields:
+- `pending_auth_session_id`
- `identity_id`
- `adopt_display_name`
- `adopt_avatar`
- `decided_at`
- timestamps
+Rules:
+
+- one adoption-decision row exists per pending session
+- `identity_id` is filled once final account creation or bind succeeds
+
### `user_avatars`
Stores the currently effective custom avatar.
@@ -265,6 +302,7 @@ WeChat-specific rule:
- `openid` never becomes the primary stored identity key
- if only `openid` is available, login/bind fails with a configuration/identity error
+- historical `openid`-only records must be reported and either remediated during migration or explicitly blocked from silent auto-upgrade
## Core Flows
@@ -285,6 +323,7 @@ WeChat-specific rule:
- Create `pending_auth_session`
- Frontend callback flow decides next action
+- Pending session creation stores immutable upstream claims separately from mutable local progress fields
Branches:
@@ -303,6 +342,7 @@ On new account creation:
- create `users` row
- create canonical third-party identity
+- create or update canonical email identity when local email binding succeeds
- apply source signup grants
- apply adoption choices if selected
@@ -310,21 +350,34 @@ On new account creation:
- current user starts bind flow
- callback resolves to `bind_current_user`
+- bind intent is tied to the initiating local user session and cannot be re-targeted by email match
- bind canonical identity to current user
- if configured and first bind for that provider, apply first-bind grants
- present nickname/avatar replacement choice
### Bind existing account during first-login flow
+- user explicitly selects bind-existing-account
- verify password for existing account
- if account requires TOTP, verify TOTP
- bind canonical identity to target account
- optionally apply first-bind grants
- present nickname/avatar replacement choice
+- no automatic profile or metadata merge occurs beyond explicitly selected nickname/avatar replacement
+
+### Callback completion and exchange flow
+
+- third-party callback never returns first-party bearer tokens in URL fragments
+- callback completion uses either:
+ - backend session completion tied to the initiating browser session
+ - one-time opaque exchange code redeemed by `POST`
+- mixed-version rollout may temporarily emit legacy pending token aliases in addition to the new completion path
+- legacy alias support is transitional and bounded to rollout windows only
### WeChat login and channel mapping
- environment chooses `mp` or `open`
+- website login uses authorization-code flow with provider-configured app/channel binding
- callback must resolve to `unionid`
- channel `openid` is optionally recorded in `auth_identity_channels`
- failure to obtain `unionid` aborts flow
@@ -344,7 +397,7 @@ On new account creation:
### User-visible methods
- `alipay`
-- `wechat`
+- `wxpay`
### Backend source abstraction
@@ -357,6 +410,12 @@ Each display method maps to exactly one active configured backend source:
Frontend submits display method only. Backend resolves display method to active source and capability set.
+### Legacy payment-config normalization
+
+- existing provider-instance `supported_types`, legacy aliases such as `wxpay_direct`, and per-type limit structures are migrated into the visible-method model
+- migration preserves historical payment capability and refund semantics
+- the system keeps one normalized visible-method mapping per provider instance for rollout and audit
+
### Alipay routing
- PC: create QR-oriented result and show QR in page
@@ -372,11 +431,25 @@ Frontend submits display method only. Backend resolves display method to active
- prefer MP/JSAPI
- if unavailable, fall back to H5 pay
+### WeChat payment OAuth recovery
+
+- if WeChat in-app payment requires `openid` and the current request does not already hold it, backend returns an `oauth_required` response instead of guessing
+- backend creates a server-backed payment-resume context containing:
+ - target visible method
+ - amount/order type/plan context
+ - redirect target
+ - anti-replay state
+- backend redirects through a dedicated WeChat payment OAuth start endpoint
+- callback exchanges the provider code server-side, stores `openid` in the payment-resume context, and returns a same-origin internal resume target
+- frontend resumes the original order flow through the resume context instead of trusting raw callback query state or long-lived local storage
+
### Payment completion
- frontend return restores context and UI state
- backend order state remains source of truth
- webhook and/or order query remain authoritative for fulfillment
+- order fulfillment validates webhook or query payload against order-time snapshot data including provider instance, merchant identifiers, amount, currency, and provider order references
+- result pages use authenticated lookup or signed opaque result tokens, never raw public `out_trade_no`
## Admin Configuration Model
@@ -420,6 +493,9 @@ Compatibility is mandatory, especially for:
- existing email users
- existing LinuxDo users
- historical LinuxDo synthetic-email accounts
+- historical WeChat synthetic-email accounts
+- historical OIDC synthetic-email accounts
+- historical WeChat `openid`-only records created by older branches
### Additive migrations
@@ -431,6 +507,8 @@ Compatibility is mandatory, especially for:
- backfill canonical `email` identities for valid existing email users
- backfill canonical `linuxdo` identities during migration for historical synthetic-email LinuxDo users
+- backfill canonical `wechat` and `oidc` identities when historical synthetic-email or `user_external_identities` data allows deterministic reconstruction
+- emit migration reports for historical WeChat `openid`-only records that cannot be safely promoted to canonical `unionid`
- backfill must be idempotent and repeatable
### Compatibility reads
@@ -438,7 +516,7 @@ Compatibility is mandatory, especially for:
During rollout:
- read new identity model first
-- where necessary, retain compatibility logic for existing email and historical LinuxDo synthetic-email recognition
+- where necessary, retain compatibility logic for existing email and historical LinuxDo/WeChat/OIDC synthetic-email recognition
### Grant idempotency
@@ -453,17 +531,20 @@ Retain transitional support for legacy/new request and response shapes where nee
- `pending_oauth_token`
- old callback parsing expectations
- historical profile field mappings
+- legacy callback fragment readers during the bounded rollout window
### Settings and payment compatibility
- preserve existing payment configs and order semantics from `main`
- add new settings incrementally
- avoid rewriting the entire settings schema in one cutover
+- preserve legacy provider-instance capabilities by explicitly mapping historical `supported_types`, `payment_mode`, and limit config into normalized visible-method routing
### Rolling upgrade tolerance
- do not assume simultaneous frontend/backend deployment
- new backend must tolerate short-lived older frontend request shapes
+- rollout must define the deployment order and removal point for legacy callback token parsing and legacy payment resume parsing
## Testing Strategy
@@ -509,9 +590,13 @@ Retain transitional support for legacy/new request and response shapes where nee
- existing email users
- historical LinuxDo synthetic-email users
+- historical WeChat synthetic-email users
+- historical OIDC synthetic-email users
+- historical WeChat `openid`-only records reported or remediated correctly
- historical payment config
- legacy auth payload field names
- historical payment result handling
+- mixed-version callback token bridge behavior
## Implementation Phases
@@ -528,9 +613,12 @@ Retain transitional support for legacy/new request and response shapes where nee
Implementation must follow current primary-source guidance:
- OAuth 2.0 Security BCP (RFC 9700): strict redirect handling, state protection, mix-up resistant design
-- PKCE (RFC 7636): use on authorization code flows where applicable
+- PKCE (RFC 7636): require `S256` on browser authorization-code flows
- OpenID Connect Core: stable issuer/subject handling for OIDC identities
- Account linking best practice: require explicit user confirmation or re-authentication before linking to existing accounts
+- WeChat UnionID and website-login guidance: treat `unionid` as canonical cross-channel subject and persist channel/app binding with website login responses
+- WeChat Pay webhook guidance: verify signatures, decrypt payloads, and confirm merchant/order/amount fields against order-time state before fulfillment
+- Payment success-page guidance: custom success pages are informational and must not be the only fulfillment trigger
References:
@@ -538,6 +626,10 @@ References:
- RFC 7636:
- OpenID Connect Core 1.0:
- Auth0 account linking guidance:
+- WeChat UnionID guidance:
+- WeChat website login guidance:
+- WeChat Pay callback/signature guidance:
+- Stripe Checkout fulfillment guidance:
## Audit Synthesis
@@ -553,7 +645,7 @@ The clean rebuild direction is not to copy either existing branch directly.
- `personal-dev-branch` has the better real-world closure:
- LinuxDo and WeChat callback flows are more operationally complete
- profile binding and avatar UX is more complete
- - historical synthetic-email users are recognized and recovered in live flows
+ - historical synthetic-email users across multiple providers are recognized and recovered in live flows
- WeChat payment OAuth and recovery behavior is more complete
- Primary-source guidance supplies hard constraints for OAuth/OIDC, account linking, WeChat identity handling, and payment completion semantics.
@@ -584,6 +676,7 @@ Keep these operational flow ideas from `personal-dev-branch`:
- profile bindings UX and “cannot disconnect last usable login method” rule
- separate WeChat login OAuth and WeChat payment OAuth entry points
- historical synthetic-email recognition logic as a migration bridge
+- explicit WeChat payment OAuth recovery protocol as a product requirement, but reimplemented with server-backed resume state
### Adapt
@@ -595,6 +688,7 @@ These areas must be reimplemented with the same intent but stricter boundaries:
- WeChat payment recovery state must move from frontend-only storage to server-backed continuation state
- avatar adoption fetches must be security-hardened and failure-visible
- pending-auth payload modeling must clearly separate immutable upstream payload from mutable local metadata
+- callback completion must use a real exchange/session model instead of fragment-delivered bearer tokens
- profile binding/avatar DTOs must be simplified to one authoritative backend contract instead of sprawling frontend fallback parsing
- admin settings should preserve capability while reducing duplicated or transitional config branches
@@ -614,7 +708,7 @@ The audit and source review establish these hard constraints:
### Auth
-- all authorization-code providers use PKCE where applicable
+- all browser authorization-code providers use PKCE `S256` and do not expose an admin-off switch
- callback handling uses strict `redirect_uri` discipline and state validation
- OIDC identity key is `issuer + sub`
- existing-account linking after email conflict must require explicit user action plus local-account verification
@@ -624,7 +718,8 @@ The audit and source review establish these hard constraints:
- existing email users must continue to work with no manual intervention
- existing LinuxDo users must not split into duplicate accounts
-- historical LinuxDo synthetic-email users must be backfilled into canonical LinuxDo identities during migration, not only lazily on next login
+- historical LinuxDo/WeChat/OIDC synthetic-email users must be backfilled into canonical identities during migration when deterministic recovery is possible
+- historical WeChat `openid`-only records must be surfaced through migration reporting and explicit remediation rules
- migration backfills must not trigger signup or first-bind grants
- legacy `pending_auth_token` and `pending_oauth_token` contracts must remain accepted during rollout
- legacy auth/public setting aliases needed by older frontend builds must remain available during rollout
@@ -634,7 +729,9 @@ The audit and source review establish these hard constraints:
- frontend return pages do not determine final payment success
- backend order state, webhook processing, and/or provider status query remain authoritative
-- each visible method (`alipay`, `wechat`) may have only one active backend source at a time
+- each visible method (`alipay`, `wxpay`) may have only one active backend source at a time
+- public result pages must not expose raw `out_trade_no` lookup
+- WeChat Pay callback handling must verify signature, decrypt payload, and compare order fields against order-time snapshot data
## Known Risks To Eliminate In Implementation
@@ -649,6 +746,7 @@ These are specifically observed problems in the existing branches that the clean
- WeChat payment recovery in `personal-dev-branch` is frontend-local and not robust across tabs or concurrent attempts
- the existing pending-auth migration update is too destructive to reuse unchanged in a safer rollout
- historical provider provenance should not be permanently flattened to `signup_source = email`
+- design/plan drift can reintroduce ambiguous identity uniqueness or ambiguous adoption-decision ownership if not aligned before implementation
## Rollout Gates
@@ -656,9 +754,10 @@ The rebuild is not ready for rollout until all of these are satisfied:
1. Identity schema and migration chain are linearized and production-safe.
2. Email identity backfill is complete and idempotent.
-3. Historical LinuxDo synthetic-email backfill to canonical LinuxDo identity is complete and idempotent.
-4. `signup_source` backfill is accurate for known historical provider-created users.
-5. Dual token acceptance and required legacy field aliases are present.
-6. Existing payment configs are normalized and verified against current frontend-visible capabilities.
-7. New frontend flows are verified against mixed-version backend compatibility windows.
-8. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
+3. Historical LinuxDo/WeChat/OIDC synthetic-email backfill to canonical identity is complete where deterministic, and non-recoverable rows are reported.
+4. Historical WeChat `openid`-only rows are either remediated or explicitly blocked with operator-visible reporting.
+5. `signup_source` backfill is accurate for known historical provider-created users.
+6. Dual token acceptance, exchange bridge behavior, and required legacy field aliases are present for the bounded rollout window.
+7. Existing payment configs are normalized and verified against current frontend-visible capabilities.
+8. New frontend flows are verified against mixed-version backend compatibility windows.
+9. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
--
GitLab
From d3d42677311518b10b738e4180aefd484125ecac Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 16:23:42 +0800
Subject: [PATCH 031/261] fix: harden oidc callback security
---
backend/internal/config/config_test.go | 2 +-
backend/internal/handler/auth_oidc_oauth.go | 203 ++++++++++++------
.../internal/handler/auth_oidc_oauth_test.go | 15 --
frontend/src/views/admin/SettingsView.vue | 10 +-
4 files changed, 143 insertions(+), 87 deletions(-)
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index cf58316c..964dbb88 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -334,7 +334,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
- cfg.LinuxDo.UsePKCE = false
+ cfg.LinuxDo.UsePKCE = true
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 9d24df88..37ef6833 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -127,30 +127,34 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
codeChallenge := ""
- if cfg.UsePKCE {
- verifier, genErr := oauth.GenerateCodeVerifier()
- if genErr != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
- return
- }
- codeChallenge = oauth.GenerateCodeChallenge(verifier)
- oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
+ verifier, genErr := oauth.GenerateCodeVerifier()
+ if genErr != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
+ return
}
+ codeChallenge = oauth.GenerateCodeChallenge(verifier)
+ oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
nonce := ""
- if cfg.ValidateIDToken {
- nonce, err = oauth.GenerateState()
- if err != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
- return
- }
- oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
+ nonce, err = oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
+ return
}
+ oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@@ -212,23 +216,24 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
codeVerifier := ""
- if cfg.UsePKCE {
- codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
- if codeVerifier == "" {
- redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
- return
- }
+ codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
+ if codeVerifier == "" {
+ redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
+ return
}
expectedNonce := ""
- if cfg.ValidateIDToken {
- expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
- if expectedNonce == "" {
- redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
- return
- }
+ expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
+ if expectedNonce == "" {
+ redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
+ return
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@@ -258,7 +263,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
- if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" {
+ if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
@@ -304,9 +309,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
}
+ if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
+ redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
+ return
+ }
identityKey := oidcIdentityKey(issuer, subject)
- email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey)
+ email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
@@ -318,34 +327,73 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ },
+ CompletionResponse: map[string]any{
+ "error": "invitation_required",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ redirectToFrontendCallback(c, frontendCallback)
return
}
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ },
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
}
type completeOIDCOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@@ -358,9 +406,38 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
@@ -369,6 +446,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -405,9 +490,7 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
- form.Set("code_verifier", codeVerifier)
- }
+ form.Set("code_verifier", codeVerifier)
r := client.R().
SetContext(ctx).
@@ -592,13 +675,9 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if strings.TrimSpace(nonce) != "" {
- q.Set("nonce", nonce)
- }
- if cfg.UsePKCE {
- q.Set("code_challenge", codeChallenge)
- q.Set("code_challenge_method", "S256")
- }
+ q.Set("nonce", nonce)
+ q.Set("code_challenge", codeChallenge)
+ q.Set("code_challenge_method", "S256")
u.RawQuery = q.Encode()
return u.String(), nil
@@ -831,14 +910,6 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
}
-func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string {
- email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail))
- if email != "" {
- return email
- }
- return oidcSyntheticEmailFromIdentityKey(identityKey)
-}
-
func oidcFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index a161aa77..a4cf776a 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -30,26 +30,11 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
require.Contains(t, e1, "@oidc-connect.invalid")
}
-func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) {
- identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a")
-
- email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey)
- require.Equal(t, "user@example.com", email)
-
- email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey)
- require.Equal(t, "idtoken@example.com", email)
-
- email = oidcSelectLoginEmail("", "", identityKey)
- require.Contains(t, email, "@oidc-connect.invalid")
- require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email)
-}
-
func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
cfg := config.OIDCConnectConfig{
AuthorizeURL: "https://issuer.example.com/auth",
ClientID: "cid",
Scopes: "openid email profile",
- UsePKCE: true,
}
u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index ee6a4c6d..8bfa0f2b 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1382,7 +1382,7 @@
{{ t('admin.settings.oidc.usePkce') }}
-
+
@@ -1391,7 +1391,7 @@
{{ t('admin.settings.oidc.validateIdToken') }}
-
+
@@ -3024,7 +3024,7 @@ const form = reactive
({
oidc_connect_redirect_url: '',
oidc_connect_frontend_redirect_url: '/auth/oidc/callback',
oidc_connect_token_auth_method: 'client_secret_post',
- oidc_connect_use_pkce: false,
+ oidc_connect_use_pkce: true,
oidc_connect_validate_id_token: true,
oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256',
oidc_connect_clock_skew_seconds: 120,
@@ -3613,8 +3613,8 @@ async function saveSettings() {
oidc_connect_redirect_url: form.oidc_connect_redirect_url,
oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url,
oidc_connect_token_auth_method: form.oidc_connect_token_auth_method,
- oidc_connect_use_pkce: form.oidc_connect_use_pkce,
- oidc_connect_validate_id_token: form.oidc_connect_validate_id_token,
+ oidc_connect_use_pkce: true,
+ oidc_connect_validate_id_token: true,
oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs,
oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds,
oidc_connect_require_email_verified: form.oidc_connect_require_email_verified,
--
GitLab
From fbd0a2e3c488720025a3408e9db234407d8aef9b Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 16:27:23 +0800
Subject: [PATCH 032/261] feat: carry suggested third-party profile through
pending oauth
---
.../internal/handler/auth_linuxdo_oauth.go | 201 +++++++++----
.../handler/auth_linuxdo_oauth_test.go | 12 +-
.../handler/auth_oauth_pending_flow.go | 263 ++++++++++++++++++
.../handler/auth_oauth_pending_flow_test.go | 40 +++
backend/internal/handler/auth_oidc_oauth.go | 47 +++-
.../internal/handler/auth_oidc_oauth_test.go | 20 ++
frontend/src/api/auth.ts | 24 +-
7 files changed, 534 insertions(+), 73 deletions(-)
create mode 100644 backend/internal/handler/auth_oauth_pending_flow.go
create mode 100644 backend/internal/handler/auth_oauth_pending_flow_test.go
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 0c7c2da7..2f182642 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -87,20 +87,25 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
- codeChallenge := ""
- if cfg.UsePKCE {
- verifier, err := oauth.GenerateCodeVerifier()
- if err != nil {
- response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
- return
- }
- codeChallenge = oauth.GenerateCodeChallenge(verifier)
- setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ verifier, err := oauth.GenerateCodeVerifier()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
+ return
}
+ codeChallenge := oauth.GenerateCodeChallenge(verifier)
+ setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
@@ -161,14 +166,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
- codeVerifier := ""
- if cfg.UsePKCE {
- codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
- if codeVerifier == "" {
- redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
- return
- }
+ codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie)
+ if codeVerifier == "" {
+ redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
+ return
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
@@ -198,7 +205,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
+ email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
@@ -215,16 +222,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ },
+ CompletionResponse: map[string]any{
+ "error": "invitation_required",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ redirectToFrontendCallback(c, frontendCallback)
return
}
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
@@ -232,18 +255,39 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "login",
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ },
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ },
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
}
type completeLinuxDoOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -256,9 +300,38 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
@@ -267,6 +340,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -303,9 +384,7 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
- form.Set("code_verifier", codeVerifier)
- }
+ form.Set("code_verifier", codeVerifier)
r := client.R().
SetContext(ctx).
@@ -353,11 +432,11 @@ func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
-) (email string, username string, subject string, err error) {
+) (email string, username string, subject string, displayName string, avatarURL string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
- return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
+ return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
@@ -366,16 +445,16 @@ func linuxDoFetchUserInfo(
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
- return "", "", "", fmt.Errorf("request userinfo: %w", err)
+ return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
- return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
+ return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
-func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
+func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
@@ -400,12 +479,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
getGJSON(body, "user.id"),
)
+ displayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "user.name"),
+ getGJSON(body, "user.username"),
+ username,
+ )
+ avatarURL = firstNonEmpty(
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "picture"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
+
subject = strings.TrimSpace(subject)
if subject == "" {
- return "", "", "", errors.New("userinfo missing id field")
+ return "", "", "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
- return "", "", "", errors.New("userinfo returned invalid id field")
+ return "", "", "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
@@ -418,8 +514,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
if username == "" {
username = "linuxdo_" + subject
}
+ displayName = strings.TrimSpace(displayName)
+ if displayName == "" {
+ displayName = username
+ }
+ avatarURL = strings.TrimSpace(avatarURL)
- return email, username, subject, nil
+ return email, username, subject, displayName, avatarURL, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
@@ -436,10 +537,8 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if cfg.UsePKCE {
- q.Set("code_challenge", codeChallenge)
- q.Set("code_challenge_method", "S256")
- }
+ q.Set("code_challenge", codeChallenge)
+ q.Set("code_challenge_method", "S256")
u.RawQuery = q.Encode()
return u.String(), nil
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index ff169c52..90bc10d1 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -41,11 +41,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "Alice", displayName)
+ require.Equal(t, "https://cdn.example/avatar.png", avatarURL)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
@@ -53,11 +55,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "linuxdo_123", displayName)
+ require.Equal(t, "", avatarURL)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
@@ -65,11 +69,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
+ _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
- _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
+ _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
new file mode 100644
index 00000000..a758c0b9
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -0,0 +1,263 @@
+package handler
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
+ oauthPendingBrowserCookieName = "oauth_pending_browser_session"
+ oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending"
+ oauthPendingSessionCookieName = "oauth_pending_session"
+ oauthPendingCookieMaxAgeSec = 10 * 60
+
+ oauthCompletionResponseKey = "completion_response"
+)
+
+type oauthPendingSessionPayload struct {
+ Intent string
+ Identity service.PendingAuthIdentityKey
+ ResolvedEmail string
+ RedirectTo string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ CompletionResponse map[string]any
+}
+
+func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
+ if h == nil || h.authService == nil || h.authService.EntClient() == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil
+}
+
+func generateOAuthPendingBrowserSession() (string, error) {
+ return oauth.GenerateState()
+}
+
+func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: encodeCookieValue(sessionKey),
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: "",
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingBrowserCookieName)
+}
+
+func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: encodeCookieValue(sessionToken),
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: "",
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingSessionCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingSessionCookieName)
+}
+
+func redirectToFrontendCallback(c *gin.Context, frontendCallback string) {
+ u, err := url.Parse(frontendCallback)
+ if err != nil {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ u.Fragment = ""
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.Redirect(http.StatusFound, u.String())
+}
+
+func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error {
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return err
+ }
+
+ session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
+ Intent: strings.TrimSpace(payload.Intent),
+ Identity: payload.Identity,
+ ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
+ RedirectTo: strings.TrimSpace(payload.RedirectTo),
+ BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
+ UpstreamIdentityClaims: payload.UpstreamIdentityClaims,
+ LocalFlowState: map[string]any{
+ oauthCompletionResponseKey: payload.CompletionResponse,
+ },
+ })
+ if err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
+ }
+
+ setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c))
+ return nil
+}
+
+func readCompletionResponse(session map[string]any) (map[string]any, bool) {
+ if len(session) == 0 {
+ return nil, false
+ }
+ value, ok := session[oauthCompletionResponseKey]
+ if !ok {
+ return nil, false
+ }
+ result, ok := value.(map[string]any)
+ if !ok {
+ return nil, false
+ }
+ return result, true
+}
+
+func pendingSessionStringValue(values map[string]any, key string) string {
+ if len(values) == 0 {
+ return ""
+ }
+ raw, ok := values[key]
+ if !ok {
+ return ""
+ }
+ value, ok := raw.(string)
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(value)
+}
+
+func pendingSessionWantsInvitation(payload map[string]any) bool {
+ return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
+}
+
+func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
+ if len(payload) == 0 || len(upstream) == 0 {
+ return
+ }
+
+ displayName := pendingSessionStringValue(upstream, "suggested_display_name")
+ avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url")
+
+ if displayName != "" {
+ if _, exists := payload["suggested_display_name"]; !exists {
+ payload["suggested_display_name"] = displayName
+ }
+ }
+ if avatarURL != "" {
+ if _, exists := payload["suggested_avatar_url"]; !exists {
+ payload["suggested_avatar_url"] = avatarURL
+ }
+ }
+ if displayName != "" || avatarURL != "" {
+ payload["adoption_required"] = true
+ }
+}
+
+// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
+// POST /api/v1/auth/oauth/pending/exchange
+func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ payload, ok := readCompletionResponse(session.LocalFlowState)
+ if !ok {
+ clearCookies()
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
+ return
+ }
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := payload["redirect"]; !exists {
+ payload["redirect"] = session.RedirectTo
+ }
+ }
+ applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
+
+ if pendingSessionWantsInvitation(payload) {
+ response.Success(c, payload)
+ return
+ }
+
+ if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ response.Success(c, payload)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
new file mode 100644
index 00000000..5517bae2
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -0,0 +1,40 @@
+package handler
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
+ payload := map[string]any{
+ "access_token": "token",
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Alice", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
+
+func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) {
+ payload := map[string]any{
+ "suggested_display_name": "Existing",
+ "adoption_required": false,
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Existing", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 37ef6833..e3694c8f 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -87,6 +87,8 @@ type oidcUserInfoClaims struct {
Username string
Subject string
EmailVerified *bool
+ DisplayName string
+ AvatarURL string
}
type oidcJWKSet struct {
@@ -338,12 +340,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
},
CompletionResponse: map[string]any{
"error": "invitation_required",
@@ -371,12 +375,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
},
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
@@ -643,9 +649,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC
if verified, ok := getGJSONBool(body, "email_verified"); ok {
claims.EmailVerified = &verified
}
+ claims.DisplayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "preferred_username"),
+ getGJSON(body, "username"),
+ )
+ claims.AvatarURL = firstNonEmpty(
+ getGJSON(body, "picture"),
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
claims.Email = strings.TrimSpace(claims.Email)
claims.Username = strings.TrimSpace(claims.Username)
claims.Subject = strings.TrimSpace(claims.Subject)
+ claims.DisplayName = strings.TrimSpace(claims.DisplayName)
+ claims.AvatarURL = strings.TrimSpace(claims.AvatarURL)
return claims
}
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index a4cf776a..c389db51 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -91,6 +91,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) {
require.Error(t, err)
}
+func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) {
+ cfg := config.OIDCConnectConfig{}
+
+ claims := oidcParseUserInfo(`{
+ "sub":"subject-1",
+ "preferred_username":"alice",
+ "name":"Alice Example",
+ "picture":"https://cdn.example/avatar.png",
+ "email":"alice@example.com",
+ "email_verified":true
+ }`, cfg)
+
+ require.Equal(t, "subject-1", claims.Subject)
+ require.Equal(t, "alice", claims.Username)
+ require.Equal(t, "Alice Example", claims.DisplayName)
+ require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL)
+ require.NotNil(t, claims.EmailVerified)
+ require.True(t, *claims.EmailVerified)
+}
+
func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 837c4f4c..d7abcd6a 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -186,6 +186,18 @@ export interface RefreshTokenResponse {
token_type: string
}
+export interface PendingOAuthExchangeResponse {
+ access_token?: string
+ refresh_token?: string
+ expires_in?: number
+ token_type?: string
+ redirect?: string
+ error?: string
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -337,12 +349,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
const { data } = await apiClient.post<{
@@ -351,7 +361,6 @@ export async function completeLinuxDoOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/linuxdo/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
invitation_code: invitationCode
})
return data
@@ -359,12 +368,10 @@ export async function completeLinuxDoOAuthRegistration(
/**
* Complete OIDC OAuth registration by supplying an invitation code
- * @param pendingOAuthToken - Short-lived JWT from the OAuth callback
* @param invitationCode - Invitation code entered by the user
* @returns Token pair on success
*/
export async function completeOIDCOAuthRegistration(
- pendingOAuthToken: string,
invitationCode: string
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
const { data } = await apiClient.post<{
@@ -373,12 +380,16 @@ export async function completeOIDCOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/oidc/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
invitation_code: invitationCode
})
return data
}
+export async function exchangePendingOAuthCompletion(): Promise<PendingOAuthExchangeResponse> {
+  const { data } = await apiClient.post<PendingOAuthExchangeResponse>('/auth/oauth/pending/exchange', {})
+ return data
+}
+
export const authAPI = {
login,
login2FA,
@@ -402,6 +413,7 @@ export const authAPI = {
resetPassword,
refreshToken,
revokeAllSessions,
+ exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
completeOIDCOAuthRegistration
}
--
GitLab
From e9de839d8791151314873bad11ac9583dabd2738 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 17:39:57 +0800
Subject: [PATCH 033/261] feat: rebuild auth identity foundation flow
---
backend/ent/authidentity.go | 266 +
backend/ent/authidentity/authidentity.go | 209 +
backend/ent/authidentity/where.go | 600 +
backend/ent/authidentity_create.go | 1036 +
backend/ent/authidentity_delete.go | 88 +
backend/ent/authidentity_query.go | 797 +
backend/ent/authidentity_update.go | 923 +
backend/ent/authidentitychannel.go | 228 +
.../authidentitychannel.go | 153 +
backend/ent/authidentitychannel/where.go | 559 +
backend/ent/authidentitychannel_create.go | 932 +
backend/ent/authidentitychannel_delete.go | 88 +
backend/ent/authidentitychannel_query.go | 643 +
backend/ent/authidentitychannel_update.go | 581 +
backend/ent/client.go | 876 +-
backend/ent/ent.go | 60 +-
backend/ent/hook/hook.go | 48 +
backend/ent/identityadoptiondecision.go | 223 +
.../identityadoptiondecision.go | 159 +
backend/ent/identityadoptiondecision/where.go | 342 +
.../ent/identityadoptiondecision_create.go | 843 +
.../ent/identityadoptiondecision_delete.go | 88 +
backend/ent/identityadoptiondecision_query.go | 721 +
.../ent/identityadoptiondecision_update.go | 532 +
backend/ent/intercept/intercept.go | 120 +
backend/ent/migrate/schema.go | 216 +
backend/ent/mutation.go | 17621 ++++++++++------
backend/ent/pendingauthsession.go | 399 +
.../pendingauthsession/pendingauthsession.go | 279 +
backend/ent/pendingauthsession/where.go | 1262 ++
backend/ent/pendingauthsession_create.go | 1815 ++
backend/ent/pendingauthsession_delete.go | 88 +
backend/ent/pendingauthsession_query.go | 717 +
backend/ent/pendingauthsession_update.go | 1178 ++
backend/ent/predicate/predicate.go | 12 +
backend/ent/runtime/runtime.go | 266 +-
backend/ent/schema/auth_identity.go | 93 +
backend/ent/schema/auth_identity_channel.go | 72 +
.../ent/schema/auth_identity_schema_test.go | 124 +
.../ent/schema/identity_adoption_decision.go | 70 +
backend/ent/schema/pending_auth_session.go | 134 +
backend/ent/schema/user.go | 13 +
backend/ent/tx.go | 12 +
backend/ent/user.go | 79 +-
backend/ent/user/user.go | 88 +
backend/ent/user/where.go | 226 +
backend/ent/user_create.go | 290 +
backend/ent/user_query.go | 155 +-
backend/ent/user_update.go | 474 +
backend/internal/config/config.go | 15 +-
.../internal/handler/admin/setting_handler.go | 189 +-
...tting_handler_auth_source_defaults_test.go | 149 +
.../internal/handler/auth_linuxdo_oauth.go | 21 +-
.../handler/auth_linuxdo_oauth_test.go | 87 +
.../handler/auth_oauth_pending_flow.go | 325 +
.../handler/auth_oauth_pending_flow_test.go | 457 +
backend/internal/handler/auth_oidc_oauth.go | 21 +-
.../internal/handler/auth_oidc_oauth_test.go | 84 +
backend/internal/handler/auth_wechat_oauth.go | 618 +
.../handler/auth_wechat_oauth_test.go | 411 +
backend/internal/handler/dto/settings.go | 1 +
.../handler/payment_webhook_handler.go | 2 +-
.../handler/payment_webhook_handler_test.go | 34 +
backend/internal/handler/setting_handler.go | 1 +
backend/internal/handler/user_handler.go | 30 +-
backend/internal/handler/user_handler_test.go | 136 +
backend/internal/repository/api_key_repo.go | 6 +
.../auth_identity_migration_report.go | 148 +
.../repository/user_profile_identity_repo.go | 544 +
...ser_profile_identity_repo_contract_test.go | 428 +
backend/internal/repository/user_repo.go | 44 +
.../user_repo_sort_integration_test.go | 83 +
backend/internal/server/api_contract_test.go | 4 +-
backend/internal/server/routes/auth.go | 14 +
.../service/admin_service_apikey_test.go | 9 +
.../service/admin_service_delete_test.go | 12 +
.../service/auth_pending_identity_service.go | 326 +
.../auth_pending_identity_service_test.go | 224 +
backend/internal/service/auth_service.go | 164 +-
.../auth_service_identity_sync_test.go | 153 +
.../auth_service_pending_oauth_test.go | 146 -
backend/internal/service/domain_constants.go | 26 +
.../service/openai_account_scheduler.go | 77 +-
.../service/openai_account_scheduler_test.go | 430 +-
...enai_account_scheduler_ws_snapshot_test.go | 7 +-
.../service/payment_config_service.go | 96 +-
.../service/payment_config_service_test.go | 186 +
.../service/payment_resume_service.go | 248 +
.../service/payment_resume_service_test.go | 240 +
backend/internal/service/payment_service.go | 40 +-
backend/internal/service/setting_service.go | 256 +-
...tting_service_auth_source_defaults_test.go | 136 +
backend/internal/service/settings_view.go | 1 +
backend/internal/service/user.go | 34 +-
backend/internal/service/user_service.go | 153 +-
backend/internal/service/user_service_test.go | 180 +-
.../108_auth_identity_foundation_core.sql | 141 +
.../109_auth_identity_compat_backfill.sql | 125 +
...nding_auth_and_provider_default_grants.sql | 60 +
...11_payment_routing_and_scheduler_flags.sql | 8 +
.../api/__tests__/auth-oauth-adoption.spec.ts | 60 +
.../settings.authSourceDefaults.spec.ts | 118 +
frontend/src/api/admin/settings.ts | 127 +
frontend/src/api/auth.ts | 41 +-
.../components/auth/WechatOAuthSection.vue | 53 +
.../auth/__tests__/WechatOAuthSection.spec.ts | 74 +
.../components/payment/PaymentStatusPanel.vue | 29 +-
.../payment/__tests__/paymentFlow.spec.ts | 163 +
.../src/components/payment/paymentFlow.ts | 197 +
.../src/router/__tests__/wechat-route.spec.ts | 55 +
frontend/src/router/index.ts | 9 +
frontend/src/stores/app.ts | 1 +
frontend/src/types/index.ts | 1 +
frontend/src/views/admin/SettingsView.vue | 497 +-
.../src/views/auth/LinuxDoCallbackView.vue | 259 +-
frontend/src/views/auth/LoginView.vue | 10 +-
frontend/src/views/auth/OidcCallbackView.vue | 263 +-
frontend/src/views/auth/RegisterView.vue | 10 +-
.../src/views/auth/WechatCallbackView.vue | 361 +
.../__tests__/LinuxDoCallbackView.spec.ts | 180 +
.../auth/__tests__/OidcCallbackView.spec.ts | 191 +
.../auth/__tests__/WechatCallbackView.spec.ts | 241 +
frontend/src/views/user/PaymentView.vue | 229 +-
123 files changed, 40062 insertions(+), 7235 deletions(-)
create mode 100644 backend/ent/authidentity.go
create mode 100644 backend/ent/authidentity/authidentity.go
create mode 100644 backend/ent/authidentity/where.go
create mode 100644 backend/ent/authidentity_create.go
create mode 100644 backend/ent/authidentity_delete.go
create mode 100644 backend/ent/authidentity_query.go
create mode 100644 backend/ent/authidentity_update.go
create mode 100644 backend/ent/authidentitychannel.go
create mode 100644 backend/ent/authidentitychannel/authidentitychannel.go
create mode 100644 backend/ent/authidentitychannel/where.go
create mode 100644 backend/ent/authidentitychannel_create.go
create mode 100644 backend/ent/authidentitychannel_delete.go
create mode 100644 backend/ent/authidentitychannel_query.go
create mode 100644 backend/ent/authidentitychannel_update.go
create mode 100644 backend/ent/identityadoptiondecision.go
create mode 100644 backend/ent/identityadoptiondecision/identityadoptiondecision.go
create mode 100644 backend/ent/identityadoptiondecision/where.go
create mode 100644 backend/ent/identityadoptiondecision_create.go
create mode 100644 backend/ent/identityadoptiondecision_delete.go
create mode 100644 backend/ent/identityadoptiondecision_query.go
create mode 100644 backend/ent/identityadoptiondecision_update.go
create mode 100644 backend/ent/pendingauthsession.go
create mode 100644 backend/ent/pendingauthsession/pendingauthsession.go
create mode 100644 backend/ent/pendingauthsession/where.go
create mode 100644 backend/ent/pendingauthsession_create.go
create mode 100644 backend/ent/pendingauthsession_delete.go
create mode 100644 backend/ent/pendingauthsession_query.go
create mode 100644 backend/ent/pendingauthsession_update.go
create mode 100644 backend/ent/schema/auth_identity.go
create mode 100644 backend/ent/schema/auth_identity_channel.go
create mode 100644 backend/ent/schema/auth_identity_schema_test.go
create mode 100644 backend/ent/schema/identity_adoption_decision.go
create mode 100644 backend/ent/schema/pending_auth_session.go
create mode 100644 backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
create mode 100644 backend/internal/handler/auth_wechat_oauth.go
create mode 100644 backend/internal/handler/auth_wechat_oauth_test.go
create mode 100644 backend/internal/handler/user_handler_test.go
create mode 100644 backend/internal/repository/auth_identity_migration_report.go
create mode 100644 backend/internal/repository/user_profile_identity_repo.go
create mode 100644 backend/internal/repository/user_profile_identity_repo_contract_test.go
create mode 100644 backend/internal/service/auth_pending_identity_service.go
create mode 100644 backend/internal/service/auth_pending_identity_service_test.go
create mode 100644 backend/internal/service/auth_service_identity_sync_test.go
delete mode 100644 backend/internal/service/auth_service_pending_oauth_test.go
create mode 100644 backend/internal/service/payment_resume_service.go
create mode 100644 backend/internal/service/payment_resume_service_test.go
create mode 100644 backend/internal/service/setting_service_auth_source_defaults_test.go
create mode 100644 backend/migrations/108_auth_identity_foundation_core.sql
create mode 100644 backend/migrations/109_auth_identity_compat_backfill.sql
create mode 100644 backend/migrations/110_pending_auth_and_provider_default_grants.sql
create mode 100644 backend/migrations/111_payment_routing_and_scheduler_flags.sql
create mode 100644 frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
create mode 100644 frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
create mode 100644 frontend/src/components/auth/WechatOAuthSection.vue
create mode 100644 frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
create mode 100644 frontend/src/components/payment/__tests__/paymentFlow.spec.ts
create mode 100644 frontend/src/components/payment/paymentFlow.ts
create mode 100644 frontend/src/router/__tests__/wechat-route.spec.ts
create mode 100644 frontend/src/views/auth/WechatCallbackView.vue
create mode 100644 frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
create mode 100644 frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
create mode 100644 frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go
new file mode 100644
index 00000000..5ccfcf19
--- /dev/null
+++ b/backend/ent/authidentity.go
@@ -0,0 +1,266 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentity is the model entity for the AuthIdentity schema.
+type AuthIdentity struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // VerifiedAt holds the value of the "verified_at" field.
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ // Issuer holds the value of the "issuer" field.
+ Issuer *string `json:"issuer,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityQuery when eager-loading is set.
+ Edges AuthIdentityEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityEdges struct {
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // Channels holds the value of the channels edge.
+ Channels []*AuthIdentityChannel `json:"channels,omitempty"`
+ // AdoptionDecisions holds the value of the adoption_decisions edge.
+ AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// ChannelsOrErr returns the Channels value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) {
+ if e.loadedTypes[1] {
+ return e.Channels, nil
+ }
+ return nil, &NotLoadedError{edge: "channels"}
+}
+
+// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) {
+ if e.loadedTypes[2] {
+ return e.AdoptionDecisions, nil
+ }
+ return nil, &NotLoadedError{edge: "adoption_decisions"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentity) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentity.FieldID, authidentity.FieldUserID:
+ values[i] = new(sql.NullInt64)
+ case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer:
+ values[i] = new(sql.NullString)
+ case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentity fields.
+func (_m *AuthIdentity) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentity.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentity.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentity.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case authidentity.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentity.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentity.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case authidentity.FieldVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field verified_at", values[i])
+ } else if value.Valid {
+ _m.VerifiedAt = new(time.Time)
+ *_m.VerifiedAt = value.Time
+ }
+ case authidentity.FieldIssuer:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field issuer", values[i])
+ } else if value.Valid {
+ _m.Issuer = new(string)
+ *_m.Issuer = value.String
+ }
+ case authidentity.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentity) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryUser queries the "user" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryUser() *UserQuery {
+ return NewAuthIdentityClient(_m.config).QueryUser(_m)
+}
+
+// QueryChannels queries the "channels" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery {
+ return NewAuthIdentityClient(_m.config).QueryChannels(_m)
+}
+
+// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m)
+}
+
+// Update returns a builder for updating this AuthIdentity.
+// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne {
+ return NewAuthIdentityClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentity) Unwrap() *AuthIdentity {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentity is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentity) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentity(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.VerifiedAt; v != nil {
+ builder.WriteString("verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.Issuer; v != nil {
+ builder.WriteString("issuer=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentities is a parsable slice of AuthIdentity.
+type AuthIdentities []*AuthIdentity
diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go
new file mode 100644
index 00000000..c90be759
--- /dev/null
+++ b/backend/ent/authidentity/authidentity.go
@@ -0,0 +1,209 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentity type in the database.
+ Label = "auth_identity"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldVerifiedAt holds the string denoting the verified_at field in the database.
+ FieldVerifiedAt = "verified_at"
+ // FieldIssuer holds the string denoting the issuer field in the database.
+ FieldIssuer = "issuer"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // EdgeChannels holds the string denoting the channels edge name in mutations.
+ EdgeChannels = "channels"
+ // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations.
+ EdgeAdoptionDecisions = "adoption_decisions"
+ // Table holds the table name of the authidentity in the database.
+ Table = "auth_identities"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "auth_identities"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+ // ChannelsTable is the table that holds the channels relation/edge.
+ ChannelsTable = "auth_identity_channels"
+ // ChannelsInverseTable is the table name for the AuthIdentityChannel entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package.
+ ChannelsInverseTable = "auth_identity_channels"
+ // ChannelsColumn is the table column denoting the channels relation/edge.
+ ChannelsColumn = "identity_id"
+ // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge.
+ AdoptionDecisionsTable = "identity_adoption_decisions"
+ // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionsInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge.
+ AdoptionDecisionsColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentity fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldUserID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldVerifiedAt,
+ FieldIssuer,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentity queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByVerifiedAt orders the results by the verified_at field.
+func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc()
+}
+
+// ByIssuer orders the results by the issuer field.
+func ByIssuer(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIssuer, opts...).ToFunc()
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByChannelsCount orders the results by channels count.
+func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...)
+ }
+}
+
+// ByChannels orders the results by channels terms.
+func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByAdoptionDecisionsCount orders the results by adoption_decisions count.
+func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...)
+ }
+}
+
+// ByAdoptionDecisions orders the results by adoption_decisions terms.
+func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
+func newChannelsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(ChannelsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+}
+func newAdoptionDecisionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+}
diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go
new file mode 100644
index 00000000..3dbf3178
--- /dev/null
+++ b/backend/ent/authidentity/where.go
@@ -0,0 +1,600 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ.
+func VerifiedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ.
+func Issuer(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// VerifiedAtEQ applies the EQ predicate on the "verified_at" field.
+func VerifiedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field.
+func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIn applies the In predicate on the "verified_at" field.
+func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field.
+func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtGT applies the GT predicate on the "verified_at" field.
+func VerifiedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtGTE applies the GTE predicate on the "verified_at" field.
+func VerifiedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLT applies the LT predicate on the "verified_at" field.
+func VerifiedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLTE applies the LTE predicate on the "verified_at" field.
+func VerifiedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field.
+func VerifiedAtIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt))
+}
+
+// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field.
+func VerifiedAtNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt))
+}
+
+// IssuerEQ applies the EQ predicate on the "issuer" field.
+func IssuerEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// IssuerNEQ applies the NEQ predicate on the "issuer" field.
+func IssuerNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v))
+}
+
+// IssuerIn applies the In predicate on the "issuer" field.
+func IssuerIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...))
+}
+
+// IssuerNotIn applies the NotIn predicate on the "issuer" field.
+func IssuerNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...))
+}
+
+// IssuerGT applies the GT predicate on the "issuer" field.
+func IssuerGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v))
+}
+
+// IssuerGTE applies the GTE predicate on the "issuer" field.
+func IssuerGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v))
+}
+
+// IssuerLT applies the LT predicate on the "issuer" field.
+func IssuerLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v))
+}
+
+// IssuerLTE applies the LTE predicate on the "issuer" field.
+func IssuerLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v))
+}
+
+// IssuerContains applies the Contains predicate on the "issuer" field.
+func IssuerContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v))
+}
+
+// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field.
+func IssuerHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v))
+}
+
+// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field.
+func IssuerHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v))
+}
+
+// IssuerIsNil applies the IsNil predicate on the "issuer" field.
+func IssuerIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer))
+}
+
+// IssuerNotNil applies the NotNil predicate on the "issuer" field.
+func IssuerNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer))
+}
+
+// IssuerEqualFold applies the EqualFold predicate on the "issuer" field.
+func IssuerEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v))
+}
+
+// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field.
+func IssuerContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v))
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with the given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasChannels applies the HasEdge predicate on the "channels" edge.
+func HasChannels() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasChannelsWith applies the HasEdge predicate on the "channels" edge with the given conditions (other predicates).
+func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newChannelsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge.
+func HasAdoptionDecisions() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with the given conditions (other predicates).
+func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newAdoptionDecisionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go
new file mode 100644
index 00000000..e287705c
--- /dev/null
+++ b/backend/ent/authidentity_create.go
@@ -0,0 +1,1036 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityCreate is the builder for creating a AuthIdentity entity.
+type AuthIdentityCreate struct {
+ config
+ mutation *AuthIdentityMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetVerifiedAt(v)
+ return _c
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetIssuer sets the "issuer" field.
+func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate {
+ _c.mutation.SetIssuer(v)
+ return _c
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetIssuer(*v)
+ }
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddChannelIDs(ids...)
+ return _c
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddAdoptionDecisionIDs(ids...)
+ return _c
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentity in the database.
+func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentity.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentity.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentity.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)}
+ }
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentity{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ _node.VerifiedAt = &value
+ }
+ if value, ok := _c.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ _node.Issuer = &value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentity node.
+ AuthIdentityUpsertOne struct {
+ create *AuthIdentityCreate
+ }
+
+ // AuthIdentityUpsert is the "OnConflict" setter.
+ AuthIdentityUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUpdatedAt)
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUserID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderSubject)
+ return u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldVerifiedAt, v)
+ return u
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldIssuer, v)
+ return u
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldIssuer)
+ return u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldIssuer)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk.
+type AuthIdentityCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentity entities in the database.
+func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentity, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentity nodes.
+type AuthIdentityUpsertBulk struct {
+ create *AuthIdentityCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go
new file mode 100644
index 00000000..4f1f6f3c
--- /dev/null
+++ b/backend/ent/authidentity_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityDelete is the builder for deleting a AuthIdentity entity.
+type AuthIdentityDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity.
+type AuthIdentityDeleteOne struct {
+ _d *AuthIdentityDelete
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentity.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go
new file mode 100644
index 00000000..ff27ef3c
--- /dev/null
+++ b/backend/ent/authidentity_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityQuery is the builder for querying AuthIdentity entities.
+type AuthIdentityQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentity.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentity
+ withUser *UserQuery
+ withChannels *AuthIdentityChannelQuery
+ withAdoptionDecisions *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityQuery builder.
+func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *AuthIdentityQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryChannels chains the current query on the "channels" edge.
+func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge.
+func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentity entity from the query.
+// Returns a *NotFoundError when no AuthIdentity was found.
+func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentity.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentity ID from the query.
+// Returns a *NotFoundError when no AuthIdentity ID was found.
+func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentity.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentity entity is found.
+// Returns a *NotFoundError when no AuthIdentity entities are found.
+func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentity.Label}
+ default:
+ return nil, &NotSingularError{authidentity.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentity ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentity ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentity.Label}
+ default:
+ err = &NotSingularError{authidentity.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentities.
+func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]()
+ return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentity IDs.
+func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentity.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentity{}, _q.predicates...),
+ withUser: _q.withUser.Clone(),
+ withChannels: _q.withChannels.Clone(),
+ withAdoptionDecisions: _q.withAdoptionDecisions.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// WithChannels tells the query-builder to eager-load the nodes that are connected to
+// the "channels" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withChannels = query
+ return _q
+}
+
+// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecisions = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// GroupBy(authidentity.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentity.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// Select(authidentity.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q}
+ sbuild.label = authidentity.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentitySelect configured with the given aggregations.
+func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentity.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) {
+ var (
+ nodes = []*AuthIdentity{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withUser != nil,
+ _q.withChannels != nil,
+ _q.withAdoptionDecisions != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentity).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentity{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withChannels; query != nil {
+ if err := _q.loadChannels(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} },
+ func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecisions; query != nil {
+ if err := _q.loadAdoptionDecisions(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} },
+ func(n *AuthIdentity, e *IdentityAdoptionDecision) {
+ n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentity)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID)
+ }
+ query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for i := range fields {
+ if fields[i] != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(authidentity.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentity.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentity.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities.
+type AuthIdentityGroupBy struct {
+ selector
+ build *AuthIdentityQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities.
+type AuthIdentitySelect struct {
+ *AuthIdentityQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go
new file mode 100644
index 00000000..c457470b
--- /dev/null
+++ b/backend/ent/authidentity_update.go
@@ -0,0 +1,923 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityUpdate is the builder for updating AuthIdentity entities.
+type AuthIdentityUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity.
+type AuthIdentityUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentity entity.
+func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for _, f := range fields {
+ if !authidentity.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentity{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go
new file mode 100644
index 00000000..1ff3e5d1
--- /dev/null
+++ b/backend/ent/authidentitychannel.go
@@ -0,0 +1,228 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema.
+type AuthIdentityChannel struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID int64 `json:"identity_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // Channel holds the value of the "channel" field.
+ Channel string `json:"channel,omitempty"`
+ // ChannelAppID holds the value of the "channel_app_id" field.
+ ChannelAppID string `json:"channel_app_id,omitempty"`
+ // ChannelSubject holds the value of the "channel_subject" field.
+ ChannelSubject string `json:"channel_subject,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set.
+ Edges AuthIdentityChannelEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityChannelEdges struct {
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject:
+ values[i] = new(sql.NullString)
+ case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentityChannel fields.
+func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentitychannel.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentitychannel.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentitychannel.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = value.Int64
+ }
+ case authidentitychannel.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentitychannel.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentitychannel.FieldChannel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel", values[i])
+ } else if value.Valid {
+ _m.Channel = value.String
+ }
+ case authidentitychannel.FieldChannelAppID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_app_id", values[i])
+ } else if value.Valid {
+ _m.ChannelAppID = value.String
+ }
+ case authidentitychannel.FieldChannelSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_subject", values[i])
+ } else if value.Valid {
+ _m.ChannelSubject = value.String
+ }
+ case authidentitychannel.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity.
+func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery {
+ return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this AuthIdentityChannel.
+// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne {
+ return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentityChannel is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentityChannel) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentityChannel(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IdentityID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("channel=")
+ builder.WriteString(_m.Channel)
+ builder.WriteString(", ")
+ builder.WriteString("channel_app_id=")
+ builder.WriteString(_m.ChannelAppID)
+ builder.WriteString(", ")
+ builder.WriteString("channel_subject=")
+ builder.WriteString(_m.ChannelSubject)
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentityChannels is a parsable slice of AuthIdentityChannel.
+type AuthIdentityChannels []*AuthIdentityChannel
diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go
new file mode 100644
index 00000000..7dcc98bb
--- /dev/null
+++ b/backend/ent/authidentitychannel/authidentitychannel.go
@@ -0,0 +1,153 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentitychannel type in the database.
+ Label = "auth_identity_channel"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldChannel holds the string denoting the channel field in the database.
+ FieldChannel = "channel"
+ // FieldChannelAppID holds the string denoting the channel_app_id field in the database.
+ FieldChannelAppID = "channel_app_id"
+ // FieldChannelSubject holds the string denoting the channel_subject field in the database.
+ FieldChannelSubject = "channel_subject"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the authidentitychannel in the database.
+ Table = "auth_identity_channels"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "auth_identity_channels"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentitychannel fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldIdentityID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldChannel,
+ FieldChannelAppID,
+ FieldChannelSubject,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ ChannelValidator func(string) error
+ // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ ChannelAppIDValidator func(string) error
+ // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ ChannelSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentityChannel queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByChannel orders the results by the channel field.
+func ByChannel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannel, opts...).ToFunc()
+}
+
+// ByChannelAppID orders the results by the channel_app_id field.
+func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelAppID, opts...).ToFunc()
+}
+
+// ByChannelSubject orders the results by the channel_subject field.
+func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelSubject, opts...).ToFunc()
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go
new file mode 100644
index 00000000..827dc384
--- /dev/null
+++ b/backend/ent/authidentitychannel/where.go
@@ -0,0 +1,559 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ.
+func Channel(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ.
+func ChannelAppID(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ.
+func ChannelSubject(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ChannelEQ applies the EQ predicate on the "channel" field.
+func ChannelEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelNEQ applies the NEQ predicate on the "channel" field.
+func ChannelNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v))
+}
+
+// ChannelIn applies the In predicate on the "channel" field.
+func ChannelIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...))
+}
+
+// ChannelNotIn applies the NotIn predicate on the "channel" field.
+func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...))
+}
+
+// ChannelGT applies the GT predicate on the "channel" field.
+func ChannelGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v))
+}
+
+// ChannelGTE applies the GTE predicate on the "channel" field.
+func ChannelGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v))
+}
+
+// ChannelLT applies the LT predicate on the "channel" field.
+func ChannelLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v))
+}
+
+// ChannelLTE applies the LTE predicate on the "channel" field.
+func ChannelLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v))
+}
+
+// ChannelContains applies the Contains predicate on the "channel" field.
+func ChannelContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v))
+}
+
+// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field.
+func ChannelHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v))
+}
+
+// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field.
+func ChannelHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v))
+}
+
+// ChannelEqualFold applies the EqualFold predicate on the "channel" field.
+func ChannelEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v))
+}
+
+// ChannelContainsFold applies the ContainsFold predicate on the "channel" field.
+func ChannelContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v))
+}
+
+// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field.
+func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field.
+func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDIn applies the In predicate on the "channel_app_id" field.
+func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field.
+func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field.
+func ChannelAppIDGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field.
+func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field.
+func ChannelAppIDLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field.
+func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field.
+func ChannelAppIDContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field.
+func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field.
+func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field.
+func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field.
+func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v))
+}
+
+// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field.
+func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field.
+func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectIn applies the In predicate on the "channel_subject" field.
+func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field.
+func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectGT applies the GT predicate on the "channel_subject" field.
+func ChannelSubjectGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field.
+func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLT applies the LT predicate on the "channel_subject" field.
+func ChannelSubjectLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field.
+func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field.
+func ChannelSubjectContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field.
+func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field.
+func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field.
+func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field.
+func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v))
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go
new file mode 100644
index 00000000..4ce28479
--- /dev/null
+++ b/backend/ent/authidentitychannel_create.go
@@ -0,0 +1,932 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity.
+type AuthIdentityChannelCreate struct {
+ config
+ mutation *AuthIdentityChannelMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetChannel sets the "channel" field.
+func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannel(v)
+ return _c
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelAppID(v)
+ return _c
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelSubject(v)
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentityChannel in the database.
+func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityChannelCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentitychannel.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentitychannel.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityChannelCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)}
+ }
+ if _, ok := _c.mutation.IdentityID(); !ok {
+ return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Channel(); !ok {
+ return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)}
+ }
+ if v, ok := _c.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelAppID(); !ok {
+ return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)}
+ }
+ if v, ok := _c.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelSubject(); !ok {
+ return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)}
+ }
+ if v, ok := _c.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)}
+ }
+ if len(_c.mutation.IdentityIDs()) == 0 {
+ return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentityChannel{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ _node.Channel = value
+ }
+ if value, ok := _c.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ _node.ChannelAppID = value
+ }
+ if value, ok := _c.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ _node.ChannelSubject = value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentityChannel node.
+ AuthIdentityChannelUpsertOne struct {
+ create *AuthIdentityChannelCreate
+ }
+
+ // AuthIdentityChannelUpsert is the "OnConflict" setter.
+ AuthIdentityChannelUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldUpdatedAt)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldIdentityID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderKey)
+ return u
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannel, v)
+ return u
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannel)
+ return u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelAppID, v)
+ return u
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelAppID)
+ return u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelSubject, v)
+ return u
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelSubject)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) {
+	node, err := u.create.Save(ctx)
+	if err != nil {
+		return id, err
+	}
+	return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk.
+type AuthIdentityChannelCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityChannelCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentityChannel entities in the database.
+func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentityChannel, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityChannelMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.AuthIdentityChannel.CreateBulk(builders...).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.AuthIdentityChannelUpsert) {
+//			SetCreatedAt(v+v).
+//		}).
+//		Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk {
+	_c.conflict = opts
+	return &AuthIdentityChannelUpsertBulk{
+		create: _c,
+	}
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentityChannel nodes.
+type AuthIdentityChannelUpsertBulk struct {
+ create *AuthIdentityChannelCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go
new file mode 100644
index 00000000..1a4acac5
--- /dev/null
+++ b/backend/ent/authidentitychannel_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity.
+type AuthIdentityChannelDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list of predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete {
+	_d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity.
+type AuthIdentityChannelDeleteOne struct {
+ _d *AuthIdentityChannelDelete
+}
+
+// Where appends a list of predicates to the underlying AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+	_d._d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go
new file mode 100644
index 00000000..7a202b7f
--- /dev/null
+++ b/backend/ent/authidentitychannel_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities.
+type AuthIdentityChannelQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentitychannel.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentityChannel
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityChannelQuery builder.
+func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentityChannel entity from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel was found.
+func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentitychannel.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentityChannel ID from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel ID was found.
+func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentitychannel.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found.
+// Returns a *NotFoundError when no AuthIdentityChannel entities are found.
+func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil, &NotSingularError{authidentitychannel.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentitychannel.Label}
+ default:
+ err = &NotSingularError{authidentitychannel.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentityChannels.
+func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]()
+ return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentityChannel IDs.
+func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityChannelQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentitychannel.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// GroupBy(authidentitychannel.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityChannelGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentitychannel.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// Select(authidentitychannel.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q}
+ sbuild.label = authidentitychannel.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations.
+func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) {
+ var (
+ nodes = []*AuthIdentityChannel{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentityChannel).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentityChannel{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentityChannel)
+ for i := range nodes {
+ fk := nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for i := range fields {
+ if fields[i] != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentitychannel.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentitychannel.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevents them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities.
+type AuthIdentityChannelGroupBy struct {
+ selector
+ build *AuthIdentityChannelQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities.
+type AuthIdentityChannelSelect struct {
+ *AuthIdentityChannelQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go
new file mode 100644
index 00000000..b550c454
--- /dev/null
+++ b/backend/ent/authidentitychannel_update.go
@@ -0,0 +1,581 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities.
+type AuthIdentityChannelUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list of predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity.
+type AuthIdentityChannelUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list of predicates to the AuthIdentityChannelUpdateOne builder.
+func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentityChannel entity.
+func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for _, f := range fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentityChannel{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index e52e015a..b02f519b 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -20,12 +20,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -60,18 +64,26 @@ type Client struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -118,12 +130,16 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
+ c.AuthIdentity = NewAuthIdentityClient(c.config)
+ c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
+ c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config)
c.PaymentAuditLog = NewPaymentAuditLogClient(c.config)
c.PaymentOrder = NewPaymentOrderClient(c.config)
c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config)
+ c.PendingAuthSession = NewPendingAuthSessionClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
@@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -332,11 +356,12 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
+ case *AuthIdentityMutation:
+ return c.AuthIdentity.mutate(ctx, m)
+ case *AuthIdentityChannelMutation:
+ return c.AuthIdentityChannel.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *IdempotencyRecordMutation:
return c.IdempotencyRecord.mutate(ctx, m)
+ case *IdentityAdoptionDecisionMutation:
+ return c.IdentityAdoptionDecision.mutate(ctx, m)
case *PaymentAuditLogMutation:
return c.PaymentAuditLog.mutate(ctx, m)
case *PaymentOrderMutation:
return c.PaymentOrder.mutate(ctx, m)
case *PaymentProviderInstanceMutation:
return c.PaymentProviderInstance.mutate(ctx, m)
+ case *PendingAuthSessionMutation:
+ return c.PendingAuthSession.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
}
}
+// AuthIdentityClient is a client for the AuthIdentity schema.
+type AuthIdentityClient struct {
+ config
+}
+
+// NewAuthIdentityClient returns a client for the AuthIdentity from the given config.
+func NewAuthIdentityClient(c config) *AuthIdentityClient {
+ return &AuthIdentityClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`.
+func (c *AuthIdentityClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`.
+func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...)
+}
+
+// Create returns a builder for creating an AuthIdentity entity.
+func (c *AuthIdentityClient) Create() *AuthIdentityCreate {
+ mutation := newAuthIdentityMutation(c.config, OpCreate)
+ return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentity entities.
+func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk {
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentity.
+func (c *AuthIdentityClient) Update() *AuthIdentityUpdate {
+ mutation := newAuthIdentityMutation(c.config, OpUpdate)
+ return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentity.
+func (c *AuthIdentityClient) Delete() *AuthIdentityDelete {
+ mutation := newAuthIdentityMutation(c.config, OpDelete)
+ return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne {
+ builder := c.Delete().Where(authidentity.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentity.
+func (c *AuthIdentityClient) Query() *AuthIdentityQuery {
+ return &AuthIdentityQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentity},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns an AuthIdentity entity by its id.
+func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) {
+ return c.Query().Where(authidentity.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryUser queries the user edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryChannels queries the channels edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityClient) Hooks() []Hook {
+ return c.hooks.AuthIdentity
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentity
+}
+
+func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op())
+ }
+}
+
+// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema.
+type AuthIdentityChannelClient struct {
+ config
+}
+
+// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config.
+func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient {
+ return &AuthIdentityChannelClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...)
+}
+
+// Create returns a builder for creating an AuthIdentityChannel entity.
+func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpCreate)
+ return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities.
+func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk {
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityChannelCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdate)
+ return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete {
+ mutation := newAuthIdentityChannelMutation(c.config, OpDelete)
+ return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne {
+ builder := c.Delete().Where(authidentitychannel.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityChannelDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery {
+ return &AuthIdentityChannelQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentityChannel},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns an AuthIdentityChannel entity by its id.
+func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) {
+ return c.Query().Where(authidentitychannel.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryIdentity queries the identity edge of a AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityChannelClient) Hooks() []Hook {
+ return c.hooks.AuthIdentityChannel
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityChannelClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentityChannel
+}
+
+func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op())
+ }
+}
+
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco
}
}
+// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecisionClient struct {
+ config
+}
+
+// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config.
+func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient {
+ return &IdentityAdoptionDecisionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) {
+ c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...)
+}
+
+// Create returns a builder for creating an IdentityAdoptionDecision entity.
+func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate)
+ return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities.
+func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk {
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*IdentityAdoptionDecisionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate)
+ return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete)
+ return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne {
+ builder := c.Delete().Where(identityadoptiondecision.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &IdentityAdoptionDecisionDeleteOne{builder}
+}
+
+// Query returns a query builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery {
+ return &IdentityAdoptionDecisionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeIdentityAdoptionDecision},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns an IdentityAdoptionDecision entity by its id.
+func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) {
+ return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryIdentity queries the identity edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *IdentityAdoptionDecisionClient) Hooks() []Hook {
+ return c.hooks.IdentityAdoptionDecision
+}
+
+// Interceptors returns the client interceptors.
+func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor {
+ return c.inters.IdentityAdoptionDecision
+}
+
+func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op())
+ }
+}
+
// PaymentAuditLogClient is a client for the PaymentAuditLog schema.
type PaymentAuditLogClient struct {
config
@@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr
}
}
+// PendingAuthSessionClient is a client for the PendingAuthSession schema.
+type PendingAuthSessionClient struct {
+ config
+}
+
+// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config.
+func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient {
+ return &PendingAuthSessionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`.
+func (c *PendingAuthSessionClient) Use(hooks ...Hook) {
+ c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`.
+func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...)
+}
+
+// Create returns a builder for creating a PendingAuthSession entity.
+func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate {
+ mutation := newPendingAuthSessionMutation(c.config, OpCreate)
+ return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities.
+func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk {
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PendingAuthSessionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdate)
+ return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete {
+ mutation := newPendingAuthSessionMutation(c.config, OpDelete)
+ return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne {
+ builder := c.Delete().Where(pendingauthsession.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PendingAuthSessionDeleteOne{builder}
+}
+
+// Query returns a query builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery {
+ return &PendingAuthSessionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePendingAuthSession},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PendingAuthSession entity by its id.
+func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) {
+ return c.Query().Where(pendingauthsession.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryTargetUser queries the target_user edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *PendingAuthSessionClient) Hooks() []Hook {
+ return c.hooks.PendingAuthSession
+}
+
+// Interceptors returns the client interceptors.
+func (c *PendingAuthSessionClient) Interceptors() []Interceptor {
+ return c.inters.PendingAuthSession
+}
+
+func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op())
+ }
+}
+
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities queries the auth_identities edge of a User.
+func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User.
+func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
UserAttributeValue, UserSubscription []ent.Interceptor
}
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 96ed5e03..339e5369 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -98,32 +102,36 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- paymentauditlog.Table: paymentauditlog.ValidColumn,
- paymentorder.Table: paymentorder.ValidColumn,
- paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- subscriptionplan.Table: subscriptionplan.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index 199dacea..46ac02bc 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentity mutator.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentityChannel mutator.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
@@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary
+// function as IdentityAdoptionDecision mutator.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary
// function as PaymentAuditLog mutator.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error)
@@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary
+// function as PendingAuthSession mutator.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PendingAuthSessionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go
new file mode 100644
index 00000000..ecaee65c
--- /dev/null
+++ b/backend/ent/identityadoptiondecision.go
@@ -0,0 +1,223 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecision struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // PendingAuthSessionID holds the value of the "pending_auth_session_id" field.
+ PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID *int64 `json:"identity_id,omitempty"`
+ // AdoptDisplayName holds the value of the "adopt_display_name" field.
+ AdoptDisplayName bool `json:"adopt_display_name,omitempty"`
+ // AdoptAvatar holds the value of the "adopt_avatar" field.
+ AdoptAvatar bool `json:"adopt_avatar,omitempty"`
+ // DecidedAt holds the value of the "decided_at" field.
+ DecidedAt time.Time `json:"decided_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set.
+ Edges IdentityAdoptionDecisionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph.
+type IdentityAdoptionDecisionEdges struct {
+ // PendingAuthSession holds the value of the pending_auth_session edge.
+ PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"`
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) {
+ if e.PendingAuthSession != nil {
+ return e.PendingAuthSession, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: pendingauthsession.Label}
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_session"}
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar:
+ values[i] = new(sql.NullBool)
+ case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the IdentityAdoptionDecision fields.
+func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case identityadoptiondecision.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i])
+ } else if value.Valid {
+ _m.PendingAuthSessionID = value.Int64
+ }
+ case identityadoptiondecision.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = new(int64)
+ *_m.IdentityID = value.Int64
+ }
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i])
+ } else if value.Valid {
+ _m.AdoptDisplayName = value.Bool
+ }
+ case identityadoptiondecision.FieldAdoptAvatar:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i])
+ } else if value.Valid {
+ _m.AdoptAvatar = value.Bool
+ }
+ case identityadoptiondecision.FieldDecidedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field decided_at", values[i])
+ } else if value.Valid {
+ _m.DecidedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision.
+// This includes values selected through modifiers, order, etc.
+func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m)
+}
+
+// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this IdentityAdoptionDecision.
+// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne {
+ return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: IdentityAdoptionDecision is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *IdentityAdoptionDecision) String() string {
+ var builder strings.Builder
+ builder.WriteString("IdentityAdoptionDecision(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("pending_auth_session_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID))
+ builder.WriteString(", ")
+ if v := _m.IdentityID; v != nil {
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("adopt_display_name=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName))
+ builder.WriteString(", ")
+ builder.WriteString("adopt_avatar=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar))
+ builder.WriteString(", ")
+ builder.WriteString("decided_at=")
+ builder.WriteString(_m.DecidedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision.
+type IdentityAdoptionDecisions []*IdentityAdoptionDecision
diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
new file mode 100644
index 00000000..93adaf73
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
@@ -0,0 +1,159 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the identityadoptiondecision type in the database.
+ Label = "identity_adoption_decision"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database.
+ FieldPendingAuthSessionID = "pending_auth_session_id"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database.
+ FieldAdoptDisplayName = "adopt_display_name"
+ // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database.
+ FieldAdoptAvatar = "adopt_avatar"
+ // FieldDecidedAt holds the string denoting the decided_at field in the database.
+ FieldDecidedAt = "decided_at"
+ // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations.
+ EdgePendingAuthSession = "pending_auth_session"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the identityadoptiondecision in the database.
+ Table = "identity_adoption_decisions"
+ // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge.
+ PendingAuthSessionTable = "identity_adoption_decisions"
+ // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge.
+ PendingAuthSessionColumn = "pending_auth_session_id"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "identity_adoption_decisions"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for identityadoptiondecision fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldPendingAuthSessionID,
+ FieldIdentityID,
+ FieldAdoptDisplayName,
+ FieldAdoptAvatar,
+ FieldDecidedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field.
+ DefaultAdoptDisplayName bool
+ // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field.
+ DefaultAdoptAvatar bool
+ // DefaultDecidedAt holds the default value on creation for the "decided_at" field.
+ DefaultDecidedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the IdentityAdoptionDecision queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionID orders the results by the pending_auth_session_id field.
+func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByAdoptDisplayName orders the results by the adopt_display_name field.
+func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc()
+}
+
+// ByAdoptAvatar orders the results by the adopt_avatar field.
+func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc()
+}
+
+// ByDecidedAt orders the results by the decided_at field.
+func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDecidedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionField orders the results by pending_auth_session field.
+func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newPendingAuthSessionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go
new file mode 100644
index 00000000..1968f175
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/where.go
@@ -0,0 +1,342 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ.
+func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ.
+func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ.
+func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ.
+func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...))
+}
+
+// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field.
+func IdentityIDIsNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID))
+}
+
+// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field.
+func IdentityIDNotNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID))
+}
+
+// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field.
+func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field.
+func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAtEQ applies the EQ predicate on the "decided_at" field.
+func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field.
+func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtIn applies the In predicate on the "decided_at" field.
+func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field.
+func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtGT applies the GT predicate on the "decided_at" field.
+func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v))
+}
+
+// DecidedAtGTE applies the GTE predicate on the "decided_at" field.
+func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v))
+}
+
+// DecidedAtLT applies the LT predicate on the "decided_at" field.
+func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v))
+}
+
+// DecidedAtLTE applies the LTE predicate on the "decided_at" field.
+func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v))
+}
+
+// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge.
+func HasPendingAuthSession() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates).
+func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newPendingAuthSessionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.NotPredicates(p))
+}
diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go
new file mode 100644
index 00000000..491ba9f9
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_create.go
@@ -0,0 +1,843 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionCreate struct {
+ config
+ mutation *IdentityAdoptionDecisionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetPendingAuthSessionID(v)
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetIdentityID(*v)
+ }
+ return _c
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptDisplayName(v)
+ return _c
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptDisplayName(*v)
+ }
+ return _c
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptAvatar(v)
+ return _c
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptAvatar(*v)
+ }
+ return _c
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetDecidedAt(v)
+ return _c
+}
+
+// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetDecidedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate {
+ return _c.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _c.mutation
+}
+
+// Save creates the IdentityAdoptionDecision in the database.
+func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *IdentityAdoptionDecisionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := identityadoptiondecision.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ v := identityadoptiondecision.DefaultAdoptDisplayName
+ _c.mutation.SetAdoptDisplayName(v)
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ v := identityadoptiondecision.DefaultAdoptAvatar
+ _c.mutation.SetAdoptAvatar(v)
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ v := identityadoptiondecision.DefaultDecidedAt()
+ _c.mutation.SetDecidedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *IdentityAdoptionDecisionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)}
+ }
+ if _, ok := _c.mutation.PendingAuthSessionID(); !ok {
+ return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)}
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)}
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)}
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)}
+ }
+ if len(_c.mutation.PendingAuthSessionIDs()) == 0 {
+ return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)}
+ }
+ return nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) {
+ var (
+ _node = &IdentityAdoptionDecision{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ _node.AdoptDisplayName = value
+ }
+ if value, ok := _c.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ _node.AdoptAvatar = value
+ }
+ if value, ok := _c.mutation.DecidedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value)
+ _node.DecidedAt = value
+ }
+ if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.PendingAuthSessionID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing
+ // one IdentityAdoptionDecision node.
+ IdentityAdoptionDecisionUpsertOne struct {
+ create *IdentityAdoptionDecisionCreate
+ }
+
+ // IdentityAdoptionDecisionUpsert is the "OnConflict" setter.
+ IdentityAdoptionDecisionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldUpdatedAt)
+ return u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v)
+ return u
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetNull(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptDisplayName, v)
+ return u
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName)
+ return u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptAvatar, v)
+ return u
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := u.create.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+	node, err := u.create.Save(ctx)
+	if err != nil {
+		return id, err
+	}
+	return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk.
+type IdentityAdoptionDecisionCreateBulk struct {
+ config
+ err error
+ builders []*IdentityAdoptionDecisionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the IdentityAdoptionDecision entities in the database.
+func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*IdentityAdoptionDecision, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*IdentityAdoptionDecisionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.IdentityAdoptionDecision.CreateBulk(builders...).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+//			SetCreatedAt(v+v).
+//		}).
+//		Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk {
+	_c.conflict = opts
+	return &IdentityAdoptionDecisionUpsertBulk{
+		create: _c,
+	}
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing
+// a bulk of IdentityAdoptionDecision nodes.
+type IdentityAdoptionDecisionUpsertBulk struct {
+ create *IdentityAdoptionDecisionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := b.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go
new file mode 100644
index 00000000..ef3d328d
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDelete struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDeleteOne struct {
+ _d *IdentityAdoptionDecisionDelete
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go
new file mode 100644
index 00000000..4082d8ee
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_query.go
@@ -0,0 +1,721 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionQuery struct {
+ config
+ ctx *QueryContext
+ order []identityadoptiondecision.OrderOption
+ inters []Interceptor
+ predicates []predicate.IdentityAdoptionDecision
+ withPendingAuthSession *PendingAuthSessionQuery
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder.
+func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first IdentityAdoptionDecision entity from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision was found.
+func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first IdentityAdoptionDecision ID from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found.
+func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found.
+// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found.
+func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil, &NotSingularError{identityadoptiondecision.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{identityadoptiondecision.Label}
+ default:
+ err = &NotSingularError{identityadoptiondecision.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of IdentityAdoptionDecisions.
+func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]()
+ return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of IdentityAdoptionDecision IDs.
+func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &IdentityAdoptionDecisionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]identityadoptiondecision.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...),
+ withPendingAuthSession: _q.withPendingAuthSession.Clone(),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSession = query
+ return _q
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// GroupBy(identityadoptiondecision.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &IdentityAdoptionDecisionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = identityadoptiondecision.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// Select(identityadoptiondecision.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q}
+ sbuild.label = identityadoptiondecision.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations.
+func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) {
+ var (
+ nodes = []*IdentityAdoptionDecision{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withPendingAuthSession != nil,
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*IdentityAdoptionDecision).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &IdentityAdoptionDecision{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withPendingAuthSession; query != nil {
+ if err := _q.loadPendingAuthSession(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ fk := nodes[i].PendingAuthSessionID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(pendingauthsession.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ if nodes[i].IdentityID == nil {
+ continue
+ }
+ fk := *nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for i := range fields {
+ if fields[i] != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withPendingAuthSession != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(identityadoptiondecision.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = identityadoptiondecision.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionGroupBy struct {
+ selector
+ build *IdentityAdoptionDecisionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionSelect struct {
+ *IdentityAdoptionDecisionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v)
+}
+
+func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go
new file mode 100644
index 00000000..0ca21d27
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_update.go
@@ -0,0 +1,532 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionUpdate struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdate) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated IdentityAdoptionDecision entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for _, f := range fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &IdentityAdoptionDecision{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 8d8320bb..157c5122 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -13,12 +13,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
+// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error)
@@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
+// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
+ case *ent.AuthIdentityQuery:
+ return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil
+ case *ent.AuthIdentityChannelQuery:
+ return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.IdempotencyRecordQuery:
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
+ case *ent.IdentityAdoptionDecisionQuery:
+ return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil
case *ent.PaymentAuditLogQuery:
return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil
case *ent.PaymentOrderQuery:
return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil
case *ent.PaymentProviderInstanceQuery:
return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil
+ case *ent.PendingAuthSessionQuery:
+ return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 68bdbf55..bf41e73b 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -338,6 +338,89 @@ var (
},
},
}
+ // AuthIdentitiesColumns holds the columns for the "auth_identities" table.
+ AuthIdentitiesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // AuthIdentitiesTable holds the schema information for the "auth_identities" table.
+ AuthIdentitiesTable = &schema.Table{
+ Name: "auth_identities",
+ Columns: AuthIdentitiesColumns,
+ PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identities_users_auth_identities",
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentity_provider_type_provider_key_provider_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]},
+ },
+ {
+ Name: "authidentity_user_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ },
+ {
+ Name: "authidentity_user_id_provider_type",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]},
+ },
+ },
+ }
+ // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table.
+ AuthIdentityChannelsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel", Type: field.TypeString, Size: 20},
+ {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "identity_id", Type: field.TypeInt64},
+ }
+ // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table.
+ AuthIdentityChannelsTable = &schema.Table{
+ Name: "auth_identity_channels",
+ Columns: AuthIdentityChannelsColumns,
+ PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identity_channels_auth_identities_channels",
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]},
+ },
+ {
+ Name: "authidentitychannel_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ },
+ },
+ }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -485,6 +568,49 @@ var (
},
},
}
+ // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "adopt_display_name", Type: field.TypeBool, Default: false},
+ {Name: "adopt_avatar", Type: field.TypeBool, Default: false},
+ {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "identity_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true},
+ }
+ // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsTable = &schema.Table{
+ Name: "identity_adoption_decisions",
+ Columns: IdentityAdoptionDecisionsColumns,
+ PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ {
+ Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "identityadoptiondecision_pending_auth_session_id",
+ Unique: true,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ },
+ {
+ Name: "identityadoptiondecision_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ },
+ },
+ }
// PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table.
PaymentAuditLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -638,6 +764,72 @@ var (
},
},
}
+ // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table.
+ PendingAuthSessionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "session_token", Type: field.TypeString, Size: 255},
+ {Name: "intent", Type: field.TypeString, Size: 40},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "target_user_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table.
+ PendingAuthSessionsTable = &schema.Table{
+ Name: "pending_auth_sessions",
+ Columns: PendingAuthSessionsColumns,
+ PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "pending_auth_sessions_users_pending_auth_sessions",
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "pendingauthsession_session_token",
+ Unique: true,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[3]},
+ },
+ {
+ Name: "pendingauthsession_target_user_id",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ },
+ {
+ Name: "pendingauthsession_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[19]},
+ },
+ {
+ Name: "pendingauthsession_provider_type_provider_key_provider_subject",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]},
+ },
+ {
+ Name: "pendingauthsession_completion_code_hash",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[14]},
+ },
+ },
+ }
// PromoCodesColumns holds the columns for the "promo_codes" table.
PromoCodesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -1079,6 +1271,9 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"},
+ {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -1318,12 +1513,16 @@ var (
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
+ AuthIdentitiesTable,
+ AuthIdentityChannelsTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
+ IdentityAdoptionDecisionsTable,
PaymentAuditLogsTable,
PaymentOrdersTable,
PaymentProviderInstancesTable,
+ PendingAuthSessionsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
@@ -1365,6 +1564,14 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
+ AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable
+ AuthIdentitiesTable.Annotation = &entsql.Annotation{
+ Table: "auth_identities",
+ }
+ AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
+ Table: "auth_identity_channels",
+ }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
@@ -1374,6 +1581,11 @@ func init() {
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
Table: "idempotency_records",
}
+ IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable
+ IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{
+ Table: "identity_adoption_decisions",
+ }
PaymentAuditLogsTable.Annotation = &entsql.Annotation{
Table: "payment_audit_logs",
}
@@ -1384,6 +1596,10 @@ func init() {
PaymentProviderInstancesTable.Annotation = &entsql.Annotation{
Table: "payment_provider_instances",
}
+ PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable
+ PendingAuthSessionsTable.Annotation = &entsql.Annotation{
+ Table: "pending_auth_sessions",
+ }
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 524ccb92..12905c9a 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -17,12 +17,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -51,32 +55,36 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypePaymentAuditLog = "PaymentAuditLog"
- TypePaymentOrder = "PaymentOrder"
- TypePaymentProviderInstance = "PaymentProviderInstance"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeSubscriptionPlan = "SubscriptionPlan"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -6887,49 +6895,45 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AnnouncementRead edge %s", name)
}
-// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
-type ErrorPassthroughRuleMutation struct {
+// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph.
+type AuthIdentityMutation struct {
config
- op Op
- typ string
- id *int64
- created_at *time.Time
- updated_at *time.Time
- name *string
- enabled *bool
- priority *int
- addpriority *int
- error_codes *[]int
- appenderror_codes []int
- keywords *[]string
- appendkeywords []string
- match_mode *string
- platforms *[]string
- appendplatforms []string
- passthrough_code *bool
- response_code *int
- addresponse_code *int
- passthrough_body *bool
- custom_message *string
- skip_monitoring *bool
- description *string
- clearedFields map[string]struct{}
- done bool
- oldValue func(context.Context) (*ErrorPassthroughRule, error)
- predicates []predicate.ErrorPassthroughRule
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ verified_at *time.Time
+ issuer *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ channels map[int64]struct{}
+ removedchannels map[int64]struct{}
+ clearedchannels bool
+ adoption_decisions map[int64]struct{}
+ removedadoption_decisions map[int64]struct{}
+ clearedadoption_decisions bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentity, error)
+ predicates []predicate.AuthIdentity
}
-var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil)
+var _ ent.Mutation = (*AuthIdentityMutation)(nil)
-// errorpassthroughruleOption allows management of the mutation configuration using functional options.
-type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation)
+// authidentityOption allows management of the mutation configuration using functional options.
+type authidentityOption func(*AuthIdentityMutation)
-// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity.
-func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation {
- m := &ErrorPassthroughRuleMutation{
+// newAuthIdentityMutation creates new mutation for the AuthIdentity entity.
+func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation {
+ m := &AuthIdentityMutation{
config: c,
op: op,
- typ: TypeErrorPassthroughRule,
+ typ: TypeAuthIdentity,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
@@ -6938,20 +6942,20 @@ func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughru
return m
}
-// withErrorPassthroughRuleID sets the ID field of the mutation.
-func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption {
- return func(m *ErrorPassthroughRuleMutation) {
+// withAuthIdentityID sets the ID field of the mutation.
+func withAuthIdentityID(id int64) authidentityOption {
+ return func(m *AuthIdentityMutation) {
var (
err error
once sync.Once
- value *ErrorPassthroughRule
+ value *AuthIdentity
)
- m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) {
+ m.oldValue = func(ctx context.Context) (*AuthIdentity, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
- value, err = m.Client().ErrorPassthroughRule.Get(ctx, id)
+ value, err = m.Client().AuthIdentity.Get(ctx, id)
}
})
return value, err
@@ -6960,10 +6964,10 @@ func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption {
}
}
-// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation.
-func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption {
- return func(m *ErrorPassthroughRuleMutation) {
- m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) {
+// withAuthIdentity sets the old AuthIdentity of the mutation.
+func withAuthIdentity(node *AuthIdentity) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentity, error) {
return node, nil
}
m.id = &node.ID
@@ -6972,7 +6976,7 @@ func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOp
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m ErrorPassthroughRuleMutation) Client() *Client {
+func (m AuthIdentityMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
@@ -6980,7 +6984,7 @@ func (m ErrorPassthroughRuleMutation) Client() *Client {
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
-func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) {
+func (m AuthIdentityMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
@@ -6991,7 +6995,7 @@ func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) {
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
-func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) {
+func (m *AuthIdentityMutation) ID() (id int64, exists bool) {
if m.id == nil {
return
}
@@ -7002,7 +7006,7 @@ func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) {
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
-func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) {
+func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
@@ -7011,19 +7015,19 @@ func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error)
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
- return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx)
+ return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
// SetCreatedAt sets the "created_at" field.
-func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) {
+func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
}
// CreatedAt returns the value of the "created_at" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) {
+func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) {
v := m.created_at
if v == nil {
return
@@ -7031,10 +7035,10 @@ func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) {
return *v, true
}
-// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
@@ -7049,17 +7053,17 @@ func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time
}
// ResetCreatedAt resets all changes to the "created_at" field.
-func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() {
+func (m *AuthIdentityMutation) ResetCreatedAt() {
m.created_at = nil
}
// SetUpdatedAt sets the "updated_at" field.
-func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) {
+func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) {
m.updated_at = &t
}
// UpdatedAt returns the value of the "updated_at" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) {
+func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) {
v := m.updated_at
if v == nil {
return
@@ -7067,10 +7071,10 @@ func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) {
return *v, true
}
-// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
@@ -7085,942 +7089,1510 @@ func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time
}
// ResetUpdatedAt resets all changes to the "updated_at" field.
-func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() {
+func (m *AuthIdentityMutation) ResetUpdatedAt() {
m.updated_at = nil
}
-// SetName sets the "name" field.
-func (m *ErrorPassthroughRuleMutation) SetName(s string) {
- m.name = &s
+// SetUserID sets the "user_id" field.
+func (m *AuthIdentityMutation) SetUserID(i int64) {
+ m.user = &i
}
-// Name returns the value of the "name" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) {
- v := m.name
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *AuthIdentityMutation) UserID() (r int64, exists bool) {
+ v := m.user
if v == nil {
return
}
return *v, true
}
-// OldName returns the old "name" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldUserID returns the old "user_id" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) {
+func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldName is only allowed on UpdateOne operations")
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldName requires an ID field in the mutation")
+ return v, errors.New("OldUserID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldName: %w", err)
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
}
- return oldValue.Name, nil
+ return oldValue.UserID, nil
}
-// ResetName resets all changes to the "name" field.
-func (m *ErrorPassthroughRuleMutation) ResetName() {
- m.name = nil
+// ResetUserID resets all changes to the "user_id" field.
+func (m *AuthIdentityMutation) ResetUserID() {
+ m.user = nil
}
-// SetEnabled sets the "enabled" field.
-func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) {
- m.enabled = &b
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityMutation) SetProviderType(s string) {
+ m.provider_type = &s
}
-// Enabled returns the value of the "enabled" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) {
- v := m.enabled
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
if v == nil {
return
}
return *v, true
}
-// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) {
+func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldEnabled requires an ID field in the mutation")
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
}
- return oldValue.Enabled, nil
+ return oldValue.ProviderType, nil
}
-// ResetEnabled resets all changes to the "enabled" field.
-func (m *ErrorPassthroughRuleMutation) ResetEnabled() {
- m.enabled = nil
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityMutation) ResetProviderType() {
+ m.provider_type = nil
}
-// SetPriority sets the "priority" field.
-func (m *ErrorPassthroughRuleMutation) SetPriority(i int) {
- m.priority = &i
- m.addpriority = nil
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityMutation) SetProviderKey(s string) {
+ m.provider_key = &s
}
-// Priority returns the value of the "priority" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) {
- v := m.priority
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
if v == nil {
return
}
return *v, true
}
-// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) {
+func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPriority is only allowed on UpdateOne operations")
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPriority requires an ID field in the mutation")
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPriority: %w", err)
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
}
- return oldValue.Priority, nil
+ return oldValue.ProviderKey, nil
}
-// AddPriority adds i to the "priority" field.
-func (m *ErrorPassthroughRuleMutation) AddPriority(i int) {
- if m.addpriority != nil {
- *m.addpriority += i
- } else {
- m.addpriority = &i
- }
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityMutation) ResetProviderKey() {
+ m.provider_key = nil
}
-// AddedPriority returns the value that was added to the "priority" field in this mutation.
-func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) {
- v := m.addpriority
+// SetProviderSubject sets the "provider_subject" field.
+func (m *AuthIdentityMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
if v == nil {
return
}
return *v, true
}
-// ResetPriority resets all changes to the "priority" field.
-func (m *ErrorPassthroughRuleMutation) ResetPriority() {
- m.priority = nil
- m.addpriority = nil
+// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
}
-// SetErrorCodes sets the "error_codes" field.
-func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) {
- m.error_codes = &i
- m.appenderror_codes = nil
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *AuthIdentityMutation) ResetProviderSubject() {
+ m.provider_subject = nil
}
-// ErrorCodes returns the value of the "error_codes" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) {
- v := m.error_codes
+// SetVerifiedAt sets the "verified_at" field.
+func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) {
+ m.verified_at = &t
+}
+
+// VerifiedAt returns the value of the "verified_at" field in the mutation.
+func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) {
+ v := m.verified_at
if v == nil {
return
}
return *v, true
}
-// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) {
+func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations")
+ return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldErrorCodes requires an ID field in the mutation")
+ return v, errors.New("OldVerifiedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err)
- }
- return oldValue.ErrorCodes, nil
-}
-
-// AppendErrorCodes adds i to the "error_codes" field.
-func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) {
- m.appenderror_codes = append(m.appenderror_codes, i...)
-}
-
-// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation.
-func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) {
- if len(m.appenderror_codes) == 0 {
- return nil, false
+ return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err)
}
- return m.appenderror_codes, true
+ return oldValue.VerifiedAt, nil
}
-// ClearErrorCodes clears the value of the "error_codes" field.
-func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() {
- m.error_codes = nil
- m.appenderror_codes = nil
- m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{}
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (m *AuthIdentityMutation) ClearVerifiedAt() {
+ m.verified_at = nil
+ m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{}
}
-// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool {
- _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes]
+// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation.
+func (m *AuthIdentityMutation) VerifiedAtCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldVerifiedAt]
return ok
}
-// ResetErrorCodes resets all changes to the "error_codes" field.
-func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() {
- m.error_codes = nil
- m.appenderror_codes = nil
- delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes)
+// ResetVerifiedAt resets all changes to the "verified_at" field.
+func (m *AuthIdentityMutation) ResetVerifiedAt() {
+ m.verified_at = nil
+ delete(m.clearedFields, authidentity.FieldVerifiedAt)
}
-// SetKeywords sets the "keywords" field.
-func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) {
- m.keywords = &s
- m.appendkeywords = nil
+// SetIssuer sets the "issuer" field.
+func (m *AuthIdentityMutation) SetIssuer(s string) {
+ m.issuer = &s
}
-// Keywords returns the value of the "keywords" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) {
- v := m.keywords
+// Issuer returns the value of the "issuer" field in the mutation.
+func (m *AuthIdentityMutation) Issuer() (r string, exists bool) {
+ v := m.issuer
if v == nil {
return
}
return *v, true
}
-// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) {
+func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldKeywords is only allowed on UpdateOne operations")
+ return v, errors.New("OldIssuer is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldKeywords requires an ID field in the mutation")
+ return v, errors.New("OldIssuer requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldKeywords: %w", err)
- }
- return oldValue.Keywords, nil
-}
-
-// AppendKeywords adds s to the "keywords" field.
-func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) {
- m.appendkeywords = append(m.appendkeywords, s...)
-}
-
-// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation.
-func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) {
- if len(m.appendkeywords) == 0 {
- return nil, false
+ return v, fmt.Errorf("querying old value for OldIssuer: %w", err)
}
- return m.appendkeywords, true
+ return oldValue.Issuer, nil
}
-// ClearKeywords clears the value of the "keywords" field.
-func (m *ErrorPassthroughRuleMutation) ClearKeywords() {
- m.keywords = nil
- m.appendkeywords = nil
- m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{}
+// ClearIssuer clears the value of the "issuer" field.
+func (m *AuthIdentityMutation) ClearIssuer() {
+ m.issuer = nil
+ m.clearedFields[authidentity.FieldIssuer] = struct{}{}
}
-// KeywordsCleared returns if the "keywords" field was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool {
- _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords]
+// IssuerCleared returns if the "issuer" field was cleared in this mutation.
+func (m *AuthIdentityMutation) IssuerCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldIssuer]
return ok
}
-// ResetKeywords resets all changes to the "keywords" field.
-func (m *ErrorPassthroughRuleMutation) ResetKeywords() {
- m.keywords = nil
- m.appendkeywords = nil
- delete(m.clearedFields, errorpassthroughrule.FieldKeywords)
+// ResetIssuer resets all changes to the "issuer" field.
+func (m *AuthIdentityMutation) ResetIssuer() {
+ m.issuer = nil
+ delete(m.clearedFields, authidentity.FieldIssuer)
}
-// SetMatchMode sets the "match_mode" field.
-func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) {
- m.match_mode = &s
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
}
-// MatchMode returns the value of the "match_mode" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) {
- v := m.match_mode
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
if v == nil {
return
}
return *v, true
}
-// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
+// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) {
+func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMatchMode is only allowed on UpdateOne operations")
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldMatchMode requires an ID field in the mutation")
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldMatchMode: %w", err)
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
}
- return oldValue.MatchMode, nil
+ return oldValue.Metadata, nil
}
-// ResetMatchMode resets all changes to the "match_mode" field.
-func (m *ErrorPassthroughRuleMutation) ResetMatchMode() {
- m.match_mode = nil
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityMutation) ResetMetadata() {
+ m.metadata = nil
}
-// SetPlatforms sets the "platforms" field.
-func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) {
- m.platforms = &s
- m.appendplatforms = nil
+// ClearUser clears the "user" edge to the User entity.
+func (m *AuthIdentityMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[authidentity.FieldUserID] = struct{}{}
}
-// Platforms returns the value of the "platforms" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) {
- v := m.platforms
- if v == nil {
- return
- }
- return *v, true
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *AuthIdentityMutation) UserCleared() bool {
+ return m.cleareduser
}
-// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPlatforms is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPlatforms requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldPlatforms: %w", err)
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
}
- return oldValue.Platforms, nil
+ return
}
-// AppendPlatforms adds s to the "platforms" field.
-func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) {
- m.appendplatforms = append(m.appendplatforms, s...)
+// ResetUser resets all changes to the "user" edge.
+func (m *AuthIdentityMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
}
-// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation.
-func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) {
- if len(m.appendplatforms) == 0 {
- return nil, false
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids.
+func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) {
+ if m.channels == nil {
+ m.channels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.channels[ids[i]] = struct{}{}
}
- return m.appendplatforms, true
}
-// ClearPlatforms clears the value of the "platforms" field.
-func (m *ErrorPassthroughRuleMutation) ClearPlatforms() {
- m.platforms = nil
- m.appendplatforms = nil
- m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{}
+// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) ClearChannels() {
+ m.clearedchannels = true
}
-// PlatformsCleared returns if the "platforms" field was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool {
- _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms]
- return ok
+// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared.
+func (m *AuthIdentityMutation) ChannelsCleared() bool {
+ return m.clearedchannels
}
-// ResetPlatforms resets all changes to the "platforms" field.
-func (m *ErrorPassthroughRuleMutation) ResetPlatforms() {
- m.platforms = nil
- m.appendplatforms = nil
- delete(m.clearedFields, errorpassthroughrule.FieldPlatforms)
+// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) {
+ if m.removedchannels == nil {
+ m.removedchannels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.channels, ids[i])
+ m.removedchannels[ids[i]] = struct{}{}
+ }
}
-// SetPassthroughCode sets the "passthrough_code" field.
-func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) {
- m.passthrough_code = &b
+// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) {
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return
}
-// PassthroughCode returns the value of the "passthrough_code" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) {
- v := m.passthrough_code
- if v == nil {
- return
+// ChannelsIDs returns the "channels" edge IDs in the mutation.
+func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) {
+ for id := range m.channels {
+ ids = append(ids, id)
}
- return *v, true
+ return
}
-// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPassthroughCode requires an ID field in the mutation")
+// ResetChannels resets all changes to the "channels" edge.
+func (m *AuthIdentityMutation) ResetChannels() {
+ m.channels = nil
+ m.clearedchannels = false
+ m.removedchannels = nil
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids.
+func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) {
+ if m.adoption_decisions == nil {
+ m.adoption_decisions = make(map[int64]struct{})
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err)
+ for i := range ids {
+ m.adoption_decisions[ids[i]] = struct{}{}
}
- return oldValue.PassthroughCode, nil
}
-// ResetPassthroughCode resets all changes to the "passthrough_code" field.
-func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() {
- m.passthrough_code = nil
+// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) ClearAdoptionDecisions() {
+ m.clearedadoption_decisions = true
}
-// SetResponseCode sets the "response_code" field.
-func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) {
- m.response_code = &i
- m.addresponse_code = nil
+// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool {
+ return m.clearedadoption_decisions
}
-// ResponseCode returns the value of the "response_code" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) {
- v := m.response_code
- if v == nil {
- return
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) {
+ if m.removedadoption_decisions == nil {
+ m.removedadoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.adoption_decisions, ids[i])
+ m.removedadoption_decisions[ids[i]] = struct{}{}
}
- return *v, true
}
-// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldResponseCode is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldResponseCode requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldResponseCode: %w", err)
+// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
}
- return oldValue.ResponseCode, nil
+ return
}
-// AddResponseCode adds i to the "response_code" field.
-func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) {
- if m.addresponse_code != nil {
- *m.addresponse_code += i
- } else {
- m.addresponse_code = &i
+// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation.
+func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
}
+ return
}
-// AddedResponseCode returns the value that was added to the "response_code" field in this mutation.
-func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) {
- v := m.addresponse_code
- if v == nil {
- return
- }
- return *v, true
+// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge.
+func (m *AuthIdentityMutation) ResetAdoptionDecisions() {
+ m.adoption_decisions = nil
+ m.clearedadoption_decisions = false
+ m.removedadoption_decisions = nil
}
-// ClearResponseCode clears the value of the "response_code" field.
-func (m *ErrorPassthroughRuleMutation) ClearResponseCode() {
- m.response_code = nil
- m.addresponse_code = nil
- m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{}
+// Where appends a list predicates to the AuthIdentityMutation builder.
+func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) {
+ m.predicates = append(m.predicates, ps...)
}
-// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool {
- _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode]
- return ok
+// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentity, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
}
-// ResetResponseCode resets all changes to the "response_code" field.
-func (m *ErrorPassthroughRuleMutation) ResetResponseCode() {
- m.response_code = nil
- m.addresponse_code = nil
- delete(m.clearedFields, errorpassthroughrule.FieldResponseCode)
+// Op returns the operation name.
+func (m *AuthIdentityMutation) Op() Op {
+ return m.op
}
-// SetPassthroughBody sets the "passthrough_body" field.
-func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) {
- m.passthrough_body = &b
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityMutation) SetOp(op Op) {
+ m.op = op
}
-// PassthroughBody returns the value of the "passthrough_body" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) {
- v := m.passthrough_body
- if v == nil {
- return
- }
- return *v, true
+// Type returns the node type of this mutation (AuthIdentity).
+func (m *AuthIdentityMutation) Type() string {
+ return m.typ
}
-// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations")
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentity.FieldCreatedAt)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPassthroughBody requires an ID field in the mutation")
+ if m.updated_at != nil {
+ fields = append(fields, authidentity.FieldUpdatedAt)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err)
+ if m.user != nil {
+ fields = append(fields, authidentity.FieldUserID)
}
- return oldValue.PassthroughBody, nil
-}
-
-// ResetPassthroughBody resets all changes to the "passthrough_body" field.
-func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() {
- m.passthrough_body = nil
-}
-
-// SetCustomMessage sets the "custom_message" field.
-func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) {
- m.custom_message = &s
-}
-
-// CustomMessage returns the value of the "custom_message" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) {
- v := m.custom_message
- if v == nil {
- return
+ if m.provider_type != nil {
+ fields = append(fields, authidentity.FieldProviderType)
}
- return *v, true
-}
-
-// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations")
+ if m.provider_key != nil {
+ fields = append(fields, authidentity.FieldProviderKey)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldCustomMessage requires an ID field in the mutation")
+ if m.provider_subject != nil {
+ fields = append(fields, authidentity.FieldProviderSubject)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err)
+ if m.verified_at != nil {
+ fields = append(fields, authidentity.FieldVerifiedAt)
}
- return oldValue.CustomMessage, nil
-}
-
-// ClearCustomMessage clears the value of the "custom_message" field.
-func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() {
- m.custom_message = nil
- m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{}
+ if m.issuer != nil {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentity.FieldMetadata)
+ }
+ return fields
}
-// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool {
- _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage]
- return ok
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentity.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentity.FieldUserID:
+ return m.UserID()
+ case authidentity.FieldProviderType:
+ return m.ProviderType()
+ case authidentity.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentity.FieldProviderSubject:
+ return m.ProviderSubject()
+ case authidentity.FieldVerifiedAt:
+ return m.VerifiedAt()
+ case authidentity.FieldIssuer:
+ return m.Issuer()
+ case authidentity.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
}
-// ResetCustomMessage resets all changes to the "custom_message" field.
-func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
- m.custom_message = nil
- delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentity.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentity.FieldUserID:
+ return m.OldUserID(ctx)
+ case authidentity.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentity.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentity.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case authidentity.FieldVerifiedAt:
+ return m.OldVerifiedAt(ctx)
+ case authidentity.FieldIssuer:
+ return m.OldIssuer(ctx)
+ case authidentity.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentity field %s", name)
}
-// SetSkipMonitoring sets the "skip_monitoring" field.
-func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
- m.skip_monitoring = &b
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentity.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentity.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case authidentity.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentity.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentity.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case authidentity.FieldVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetVerifiedAt(v)
+ return nil
+ case authidentity.FieldIssuer:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIssuer(v)
+ return nil
+ case authidentity.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
}
-// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
- v := m.skip_monitoring
- if v == nil {
- return
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
}
- return *v, true
+ return nil, false
}
-// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error {
+ switch name {
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
+ return fmt.Errorf("unknown AuthIdentity numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(authidentity.FieldVerifiedAt) {
+ fields = append(fields, authidentity.FieldVerifiedAt)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
+ if m.FieldCleared(authidentity.FieldIssuer) {
+ fields = append(fields, authidentity.FieldIssuer)
}
- return oldValue.SkipMonitoring, nil
+ return fields
}
-// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
-func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
- m.skip_monitoring = nil
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
}
-// SetDescription sets the "description" field.
-func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
- m.description = &s
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ClearField(name string) error {
+ switch name {
+ case authidentity.FieldVerifiedAt:
+ m.ClearVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ClearIssuer()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity nullable field %s", name)
}
-// Description returns the value of the "description" field in the mutation.
-func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) {
- v := m.description
- if v == nil {
- return
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ResetField(name string) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentity.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentity.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case authidentity.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentity.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentity.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case authidentity.FieldVerifiedAt:
+ m.ResetVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ResetIssuer()
+ return nil
+ case authidentity.FieldMetadata:
+ m.ResetMetadata()
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
}
-// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity.
-// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.user != nil {
+ edges = append(edges, authidentity.EdgeUser)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDescription requires an ID field in the mutation")
+ if m.channels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ if m.adoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
}
- return oldValue.Description, nil
+ return edges
}
-// ClearDescription clears the value of the "description" field.
-func (m *ErrorPassthroughRuleMutation) ClearDescription() {
- m.description = nil
- m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{}
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.channels))
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.adoption_decisions))
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
}
-// DescriptionCleared returns if the "description" field was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool {
- _, ok := m.clearedFields[errorpassthroughrule.FieldDescription]
- return ok
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedchannels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.removedadoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
}
-// ResetDescription resets all changes to the "description" field.
-func (m *ErrorPassthroughRuleMutation) ResetDescription() {
- m.description = nil
- delete(m.clearedFields, errorpassthroughrule.FieldDescription)
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.removedchannels))
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.removedadoption_decisions))
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
}
-// Where appends a list predicates to the ErrorPassthroughRuleMutation builder.
-func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) {
- m.predicates = append(m.predicates, ps...)
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.cleareduser {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.clearedchannels {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.clearedadoption_decisions {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
}
-// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method,
-// users can use type-assertion to append predicates that do not depend on any generated package.
-func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) {
- p := make([]predicate.ErrorPassthroughRule, len(ps))
- for i := range ps {
- p[i] = ps[i]
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentity.EdgeUser:
+ return m.cleareduser
+ case authidentity.EdgeChannels:
+ return m.clearedchannels
+ case authidentity.EdgeAdoptionDecisions:
+ return m.clearedadoption_decisions
}
- m.Where(p...)
+ return false
}
-// Op returns the operation name.
-func (m *ErrorPassthroughRuleMutation) Op() Op {
- return m.op
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity unique edge %s", name)
}
-// SetOp allows setting the mutation operation.
-func (m *ErrorPassthroughRuleMutation) SetOp(op Op) {
- m.op = op
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ResetUser()
+ return nil
+ case authidentity.EdgeChannels:
+ m.ResetChannels()
+ return nil
+ case authidentity.EdgeAdoptionDecisions:
+ m.ResetAdoptionDecisions()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity edge %s", name)
}
-// Type returns the node type of this mutation (ErrorPassthroughRule).
-func (m *ErrorPassthroughRuleMutation) Type() string {
- return m.typ
+// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph.
+type AuthIdentityChannelMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ channel *string
+ channel_app_id *string
+ channel_subject *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentityChannel, error)
+ predicates []predicate.AuthIdentityChannel
}
-// Fields returns all fields that were changed during this mutation. Note that in
-// order to get all numeric fields that were incremented/decremented, call
-// AddedFields().
-func (m *ErrorPassthroughRuleMutation) Fields() []string {
- fields := make([]string, 0, 15)
- if m.created_at != nil {
- fields = append(fields, errorpassthroughrule.FieldCreatedAt)
+var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil)
+
+// authidentitychannelOption allows management of the mutation configuration using functional options.
+type authidentitychannelOption func(*AuthIdentityChannelMutation)
+
+// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity.
+func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation {
+ m := &AuthIdentityChannelMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentityChannel,
+ clearedFields: make(map[string]struct{}),
}
- if m.updated_at != nil {
- fields = append(fields, errorpassthroughrule.FieldUpdatedAt)
+ for _, opt := range opts {
+ opt(m)
}
- if m.name != nil {
- fields = append(fields, errorpassthroughrule.FieldName)
+ return m
+}
+
+// withAuthIdentityChannelID sets the ID field of the mutation.
+func withAuthIdentityChannelID(id int64) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentityChannel
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentityChannel.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
}
- if m.enabled != nil {
- fields = append(fields, errorpassthroughrule.FieldEnabled)
+}
+
+// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation.
+func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentityChannel, error) {
+ return node, nil
+ }
+ m.id = &node.ID
}
- if m.priority != nil {
- fields = append(fields, errorpassthroughrule.FieldPriority)
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityChannelMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityChannelMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
}
- if m.error_codes != nil {
- fields = append(fields, errorpassthroughrule.FieldErrorCodes)
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
}
- if m.keywords != nil {
- fields = append(fields, errorpassthroughrule.FieldKeywords)
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
- if m.match_mode != nil {
- fields = append(fields, errorpassthroughrule.FieldMatchMode)
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
}
- if m.platforms != nil {
- fields = append(fields, errorpassthroughrule.FieldPlatforms)
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
- if m.passthrough_code != nil {
- fields = append(fields, errorpassthroughrule.FieldPassthroughCode)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
- if m.response_code != nil {
- fields = append(fields, errorpassthroughrule.FieldResponseCode)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
- if m.passthrough_body != nil {
- fields = append(fields, errorpassthroughrule.FieldPassthroughBody)
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityChannelMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
}
- if m.custom_message != nil {
- fields = append(fields, errorpassthroughrule.FieldCustomMessage)
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
- if m.skip_monitoring != nil {
- fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
}
- if m.description != nil {
- fields = append(fields, errorpassthroughrule.FieldDescription)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
}
- return fields
+ return oldValue.UpdatedAt, nil
}
-// Field returns the value of a field with the given name. The second boolean
-// return value indicates that this field was not set, or was not defined in the
-// schema.
-func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
- switch name {
- case errorpassthroughrule.FieldCreatedAt:
- return m.CreatedAt()
- case errorpassthroughrule.FieldUpdatedAt:
- return m.UpdatedAt()
- case errorpassthroughrule.FieldName:
- return m.Name()
- case errorpassthroughrule.FieldEnabled:
- return m.Enabled()
- case errorpassthroughrule.FieldPriority:
- return m.Priority()
- case errorpassthroughrule.FieldErrorCodes:
- return m.ErrorCodes()
- case errorpassthroughrule.FieldKeywords:
- return m.Keywords()
- case errorpassthroughrule.FieldMatchMode:
- return m.MatchMode()
- case errorpassthroughrule.FieldPlatforms:
- return m.Platforms()
- case errorpassthroughrule.FieldPassthroughCode:
- return m.PassthroughCode()
- case errorpassthroughrule.FieldResponseCode:
- return m.ResponseCode()
- case errorpassthroughrule.FieldPassthroughBody:
- return m.PassthroughBody()
- case errorpassthroughrule.FieldCustomMessage:
- return m.CustomMessage()
- case errorpassthroughrule.FieldSkipMonitoring:
- return m.SkipMonitoring()
- case errorpassthroughrule.FieldDescription:
- return m.Description()
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityChannelMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
}
- return nil, false
+ return *v, true
}
-// OldField returns the old value of the field from the database. An error is
-// returned if the mutation operation is not UpdateOne, or the query to the
-// database failed.
-func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
- switch name {
- case errorpassthroughrule.FieldCreatedAt:
- return m.OldCreatedAt(ctx)
- case errorpassthroughrule.FieldUpdatedAt:
- return m.OldUpdatedAt(ctx)
- case errorpassthroughrule.FieldName:
- return m.OldName(ctx)
- case errorpassthroughrule.FieldEnabled:
- return m.OldEnabled(ctx)
- case errorpassthroughrule.FieldPriority:
- return m.OldPriority(ctx)
- case errorpassthroughrule.FieldErrorCodes:
- return m.OldErrorCodes(ctx)
- case errorpassthroughrule.FieldKeywords:
- return m.OldKeywords(ctx)
- case errorpassthroughrule.FieldMatchMode:
- return m.OldMatchMode(ctx)
- case errorpassthroughrule.FieldPlatforms:
- return m.OldPlatforms(ctx)
- case errorpassthroughrule.FieldPassthroughCode:
- return m.OldPassthroughCode(ctx)
- case errorpassthroughrule.FieldResponseCode:
- return m.OldResponseCode(ctx)
- case errorpassthroughrule.FieldPassthroughBody:
- return m.OldPassthroughBody(ctx)
- case errorpassthroughrule.FieldCustomMessage:
- return m.OldCustomMessage(ctx)
- case errorpassthroughrule.FieldSkipMonitoring:
- return m.OldSkipMonitoring(ctx)
- case errorpassthroughrule.FieldDescription:
- return m.OldDescription(ctx)
+// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
}
- return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
}
-// SetField sets the value of a field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error {
- switch name {
- case errorpassthroughrule.FieldCreatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetCreatedAt(v)
- return nil
- case errorpassthroughrule.FieldUpdatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetUpdatedAt(v)
- return nil
- case errorpassthroughrule.FieldName:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetName(v)
- return nil
- case errorpassthroughrule.FieldEnabled:
- v, ok := value.(bool)
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *AuthIdentityChannelMutation) ResetIdentityID() {
+ m.identity = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityChannelMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityChannelMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityChannelMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityChannelMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetChannel sets the "channel" field.
+func (m *AuthIdentityChannelMutation) SetChannel(s string) {
+ m.channel = &s
+}
+
+// Channel returns the value of the "channel" field in the mutation.
+func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) {
+ v := m.channel
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannel: %w", err)
+ }
+ return oldValue.Channel, nil
+}
+
+// ResetChannel resets all changes to the "channel" field.
+func (m *AuthIdentityChannelMutation) ResetChannel() {
+ m.channel = nil
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) {
+ m.channel_app_id = &s
+}
+
+// ChannelAppID returns the value of the "channel_app_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) {
+ v := m.channel_app_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelAppID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err)
+ }
+ return oldValue.ChannelAppID, nil
+}
+
+// ResetChannelAppID resets all changes to the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) ResetChannelAppID() {
+ m.channel_app_id = nil
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) {
+ m.channel_subject = &s
+}
+
+// ChannelSubject returns the value of the "channel_subject" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) {
+ v := m.channel_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err)
+ }
+ return oldValue.ChannelSubject, nil
+}
+
+// ResetChannelSubject resets all changes to the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) ResetChannelSubject() {
+ m.channel_subject = nil
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityChannelMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *AuthIdentityChannelMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *AuthIdentityChannelMutation) IdentityCleared() bool {
+ return m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *AuthIdentityChannelMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the AuthIdentityChannelMutation builder.
+func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentityChannel, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityChannelMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityChannelMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentityChannel).
+func (m *AuthIdentityChannelMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityChannelMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentitychannel.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentitychannel.FieldUpdatedAt)
+ }
+ if m.identity != nil {
+ fields = append(fields, authidentitychannel.FieldIdentityID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentitychannel.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentitychannel.FieldProviderKey)
+ }
+ if m.channel != nil {
+ fields = append(fields, authidentitychannel.FieldChannel)
+ }
+ if m.channel_app_id != nil {
+ fields = append(fields, authidentitychannel.FieldChannelAppID)
+ }
+ if m.channel_subject != nil {
+ fields = append(fields, authidentitychannel.FieldChannelSubject)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentitychannel.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentitychannel.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentitychannel.FieldIdentityID:
+ return m.IdentityID()
+ case authidentitychannel.FieldProviderType:
+ return m.ProviderType()
+ case authidentitychannel.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentitychannel.FieldChannel:
+ return m.Channel()
+ case authidentitychannel.FieldChannelAppID:
+ return m.ChannelAppID()
+ case authidentitychannel.FieldChannelSubject:
+ return m.ChannelSubject()
+ case authidentitychannel.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentitychannel.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentitychannel.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case authidentitychannel.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentitychannel.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentitychannel.FieldChannel:
+ return m.OldChannel(ctx)
+ case authidentitychannel.FieldChannelAppID:
+ return m.OldChannelAppID(ctx)
+ case authidentitychannel.FieldChannelSubject:
+ return m.OldChannelSubject(ctx)
+ case authidentitychannel.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetEnabled(v)
+ m.SetCreatedAt(v)
return nil
- case errorpassthroughrule.FieldPriority:
- v, ok := value.(int)
+ case authidentitychannel.FieldUpdatedAt:
+ v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetPriority(v)
+ m.SetUpdatedAt(v)
return nil
- case errorpassthroughrule.FieldErrorCodes:
- v, ok := value.([]int)
+ case authidentitychannel.FieldIdentityID:
+ v, ok := value.(int64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetErrorCodes(v)
+ m.SetIdentityID(v)
return nil
- case errorpassthroughrule.FieldKeywords:
- v, ok := value.([]string)
+ case authidentitychannel.FieldProviderType:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetKeywords(v)
+ m.SetProviderType(v)
return nil
- case errorpassthroughrule.FieldMatchMode:
+ case authidentitychannel.FieldProviderKey:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetMatchMode(v)
+ m.SetProviderKey(v)
return nil
- case errorpassthroughrule.FieldPlatforms:
- v, ok := value.([]string)
+ case authidentitychannel.FieldChannel:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetPlatforms(v)
+ m.SetChannel(v)
return nil
- case errorpassthroughrule.FieldPassthroughCode:
- v, ok := value.(bool)
+ case authidentitychannel.FieldChannelAppID:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetPassthroughCode(v)
+ m.SetChannelAppID(v)
return nil
- case errorpassthroughrule.FieldResponseCode:
- v, ok := value.(int)
+ case authidentitychannel.FieldChannelSubject:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetResponseCode(v)
+ m.SetChannelSubject(v)
return nil
- case errorpassthroughrule.FieldPassthroughBody:
- v, ok := value.(bool)
+ case authidentitychannel.FieldMetadata:
+ v, ok := value.(map[string]interface{})
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetPassthroughBody(v)
- return nil
- case errorpassthroughrule.FieldCustomMessage:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetCustomMessage(v)
- return nil
- case errorpassthroughrule.FieldSkipMonitoring:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSkipMonitoring(v)
- return nil
- case errorpassthroughrule.FieldDescription:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDescription(v)
+ m.SetMetadata(v)
return nil
}
- return fmt.Errorf("unknown ErrorPassthroughRule field %s", name)
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
}
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
-func (m *ErrorPassthroughRuleMutation) AddedFields() []string {
+func (m *AuthIdentityChannelMutation) AddedFields() []string {
var fields []string
- if m.addpriority != nil {
- fields = append(fields, errorpassthroughrule.FieldPriority)
- }
- if m.addresponse_code != nil {
- fields = append(fields, errorpassthroughrule.FieldResponseCode)
- }
return fields
}
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
-func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) {
+func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) {
switch name {
- case errorpassthroughrule.FieldPriority:
- return m.AddedPriority()
- case errorpassthroughrule.FieldResponseCode:
- return m.AddedResponseCode()
}
return nil, false
}
@@ -8028,290 +8600,205 @@ func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool)
// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
-func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error {
+func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error {
switch name {
- case errorpassthroughrule.FieldPriority:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddPriority(v)
- return nil
- case errorpassthroughrule.FieldResponseCode:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddResponseCode(v)
- return nil
}
- return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name)
+ return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name)
}
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
-func (m *ErrorPassthroughRuleMutation) ClearedFields() []string {
- var fields []string
- if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) {
- fields = append(fields, errorpassthroughrule.FieldErrorCodes)
- }
- if m.FieldCleared(errorpassthroughrule.FieldKeywords) {
- fields = append(fields, errorpassthroughrule.FieldKeywords)
- }
- if m.FieldCleared(errorpassthroughrule.FieldPlatforms) {
- fields = append(fields, errorpassthroughrule.FieldPlatforms)
- }
- if m.FieldCleared(errorpassthroughrule.FieldResponseCode) {
- fields = append(fields, errorpassthroughrule.FieldResponseCode)
- }
- if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) {
- fields = append(fields, errorpassthroughrule.FieldCustomMessage)
- }
- if m.FieldCleared(errorpassthroughrule.FieldDescription) {
- fields = append(fields, errorpassthroughrule.FieldDescription)
- }
- return fields
+func (m *AuthIdentityChannelMutation) ClearedFields() []string {
+ return nil
}
// FieldCleared returns a boolean indicating if a field with the given name was
// cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool {
+func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool {
_, ok := m.clearedFields[name]
return ok
}
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
-func (m *ErrorPassthroughRuleMutation) ClearField(name string) error {
- switch name {
- case errorpassthroughrule.FieldErrorCodes:
- m.ClearErrorCodes()
- return nil
- case errorpassthroughrule.FieldKeywords:
- m.ClearKeywords()
- return nil
- case errorpassthroughrule.FieldPlatforms:
- m.ClearPlatforms()
- return nil
- case errorpassthroughrule.FieldResponseCode:
- m.ClearResponseCode()
- return nil
- case errorpassthroughrule.FieldCustomMessage:
- m.ClearCustomMessage()
- return nil
- case errorpassthroughrule.FieldDescription:
- m.ClearDescription()
- return nil
- }
- return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name)
+func (m *AuthIdentityChannelMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name)
}
// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
-func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
+func (m *AuthIdentityChannelMutation) ResetField(name string) error {
switch name {
- case errorpassthroughrule.FieldCreatedAt:
+ case authidentitychannel.FieldCreatedAt:
m.ResetCreatedAt()
return nil
- case errorpassthroughrule.FieldUpdatedAt:
+ case authidentitychannel.FieldUpdatedAt:
m.ResetUpdatedAt()
return nil
- case errorpassthroughrule.FieldName:
- m.ResetName()
- return nil
- case errorpassthroughrule.FieldEnabled:
- m.ResetEnabled()
- return nil
- case errorpassthroughrule.FieldPriority:
- m.ResetPriority()
- return nil
- case errorpassthroughrule.FieldErrorCodes:
- m.ResetErrorCodes()
- return nil
- case errorpassthroughrule.FieldKeywords:
- m.ResetKeywords()
- return nil
- case errorpassthroughrule.FieldMatchMode:
- m.ResetMatchMode()
- return nil
- case errorpassthroughrule.FieldPlatforms:
- m.ResetPlatforms()
+ case authidentitychannel.FieldIdentityID:
+ m.ResetIdentityID()
return nil
- case errorpassthroughrule.FieldPassthroughCode:
- m.ResetPassthroughCode()
+ case authidentitychannel.FieldProviderType:
+ m.ResetProviderType()
return nil
- case errorpassthroughrule.FieldResponseCode:
- m.ResetResponseCode()
+ case authidentitychannel.FieldProviderKey:
+ m.ResetProviderKey()
return nil
- case errorpassthroughrule.FieldPassthroughBody:
- m.ResetPassthroughBody()
+ case authidentitychannel.FieldChannel:
+ m.ResetChannel()
return nil
- case errorpassthroughrule.FieldCustomMessage:
- m.ResetCustomMessage()
+ case authidentitychannel.FieldChannelAppID:
+ m.ResetChannelAppID()
return nil
- case errorpassthroughrule.FieldSkipMonitoring:
- m.ResetSkipMonitoring()
+ case authidentitychannel.FieldChannelSubject:
+ m.ResetChannelSubject()
return nil
- case errorpassthroughrule.FieldDescription:
- m.ResetDescription()
+ case authidentitychannel.FieldMetadata:
+ m.ResetMetadata()
return nil
}
- return fmt.Errorf("unknown ErrorPassthroughRule field %s", name)
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
-func (m *ErrorPassthroughRuleMutation) AddedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *AuthIdentityChannelMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.identity != nil {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
return edges
}
// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
-func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value {
+func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
-func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *AuthIdentityChannelMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
-func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value {
+func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value {
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *AuthIdentityChannelMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedidentity {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
-func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool {
+func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ return m.clearedidentity
+ }
return false
}
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
-func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error {
- return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name)
+func (m *AuthIdentityChannelMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name)
}
// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
-func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error {
- return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name)
-}
-
-// GroupMutation represents an operation that mutates the Group nodes in the graph.
-type GroupMutation struct {
- config
- op Op
- typ string
- id *int64
- created_at *time.Time
- updated_at *time.Time
- deleted_at *time.Time
- name *string
- description *string
- rate_multiplier *float64
- addrate_multiplier *float64
- is_exclusive *bool
- status *string
- platform *string
- subscription_type *string
- daily_limit_usd *float64
- adddaily_limit_usd *float64
- weekly_limit_usd *float64
- addweekly_limit_usd *float64
- monthly_limit_usd *float64
- addmonthly_limit_usd *float64
- default_validity_days *int
- adddefault_validity_days *int
- image_price_1k *float64
- addimage_price_1k *float64
- image_price_2k *float64
- addimage_price_2k *float64
- image_price_4k *float64
- addimage_price_4k *float64
- claude_code_only *bool
- fallback_group_id *int64
- addfallback_group_id *int64
- fallback_group_id_on_invalid_request *int64
- addfallback_group_id_on_invalid_request *int64
- model_routing *map[string][]int64
- model_routing_enabled *bool
- mcp_xml_inject *bool
- supported_model_scopes *[]string
- appendsupported_model_scopes []string
- sort_order *int
- addsort_order *int
- allow_messages_dispatch *bool
- require_oauth_only *bool
- require_privacy_set *bool
- default_mapped_model *string
- messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
- clearedFields map[string]struct{}
- api_keys map[int64]struct{}
- removedapi_keys map[int64]struct{}
- clearedapi_keys bool
- redeem_codes map[int64]struct{}
- removedredeem_codes map[int64]struct{}
- clearedredeem_codes bool
- subscriptions map[int64]struct{}
- removedsubscriptions map[int64]struct{}
- clearedsubscriptions bool
- usage_logs map[int64]struct{}
- removedusage_logs map[int64]struct{}
- clearedusage_logs bool
- accounts map[int64]struct{}
- removedaccounts map[int64]struct{}
- clearedaccounts bool
- allowed_users map[int64]struct{}
- removedallowed_users map[int64]struct{}
- clearedallowed_users bool
- done bool
- oldValue func(context.Context) (*Group, error)
- predicates []predicate.Group
-}
-
-var _ ent.Mutation = (*GroupMutation)(nil)
-
-// groupOption allows management of the mutation configuration using functional options.
-type groupOption func(*GroupMutation)
-
-// newGroupMutation creates new mutation for the Group entity.
-func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation {
- m := &GroupMutation{
- config: c,
- op: op,
- typ: TypeGroup,
- clearedFields: make(map[string]struct{}),
- }
- for _, opt := range opts {
- opt(m)
+func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
}
- return m
+ return fmt.Errorf("unknown AuthIdentityChannel edge %s", name)
}
-// withGroupID sets the ID field of the mutation.
-func withGroupID(id int64) groupOption {
- return func(m *GroupMutation) {
- var (
- err error
- once sync.Once
- value *Group
+// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
+type ErrorPassthroughRuleMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ enabled *bool
+ priority *int
+ addpriority *int
+ error_codes *[]int
+ appenderror_codes []int
+ keywords *[]string
+ appendkeywords []string
+ match_mode *string
+ platforms *[]string
+ appendplatforms []string
+ passthrough_code *bool
+ response_code *int
+ addresponse_code *int
+ passthrough_body *bool
+ custom_message *string
+ skip_monitoring *bool
+ description *string
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*ErrorPassthroughRule, error)
+ predicates []predicate.ErrorPassthroughRule
+}
+
+var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil)
+
+// errorpassthroughruleOption allows management of the mutation configuration using functional options.
+type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation)
+
+// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity.
+func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation {
+ m := &ErrorPassthroughRuleMutation{
+ config: c,
+ op: op,
+ typ: TypeErrorPassthroughRule,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withErrorPassthroughRuleID sets the ID field of the mutation.
+func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption {
+ return func(m *ErrorPassthroughRuleMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ErrorPassthroughRule
)
- m.oldValue = func(ctx context.Context) (*Group, error) {
+ m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
- value, err = m.Client().Group.Get(ctx, id)
+ value, err = m.Client().ErrorPassthroughRule.Get(ctx, id)
}
})
return value, err
@@ -8320,10 +8807,10 @@ func withGroupID(id int64) groupOption {
}
}
-// withGroup sets the old Group of the mutation.
-func withGroup(node *Group) groupOption {
- return func(m *GroupMutation) {
- m.oldValue = func(context.Context) (*Group, error) {
+// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation.
+func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption {
+ return func(m *ErrorPassthroughRuleMutation) {
+ m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) {
return node, nil
}
m.id = &node.ID
@@ -8332,7 +8819,7 @@ func withGroup(node *Group) groupOption {
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m GroupMutation) Client() *Client {
+func (m ErrorPassthroughRuleMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
@@ -8340,7 +8827,7 @@ func (m GroupMutation) Client() *Client {
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
-func (m GroupMutation) Tx() (*Tx, error) {
+func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
@@ -8351,7 +8838,7 @@ func (m GroupMutation) Tx() (*Tx, error) {
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
-func (m *GroupMutation) ID() (id int64, exists bool) {
+func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) {
if m.id == nil {
return
}
@@ -8362,7 +8849,7 @@ func (m *GroupMutation) ID() (id int64, exists bool) {
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
-func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) {
+func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
@@ -8371,19 +8858,19 @@ func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) {
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
- return m.Client().Group.Query().Where(m.predicates...).IDs(ctx)
+ return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
// SetCreatedAt sets the "created_at" field.
-func (m *GroupMutation) SetCreatedAt(t time.Time) {
+func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
}
// CreatedAt returns the value of the "created_at" field in the mutation.
-func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) {
+func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) {
v := m.created_at
if v == nil {
return
@@ -8391,10 +8878,10 @@ func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) {
return *v, true
}
-// OldCreatedAt returns the old "created_at" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
@@ -8409,17 +8896,17 @@ func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err erro
}
// ResetCreatedAt resets all changes to the "created_at" field.
-func (m *GroupMutation) ResetCreatedAt() {
+func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() {
m.created_at = nil
}
// SetUpdatedAt sets the "updated_at" field.
-func (m *GroupMutation) SetUpdatedAt(t time.Time) {
+func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) {
m.updated_at = &t
}
// UpdatedAt returns the value of the "updated_at" field in the mutation.
-func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) {
+func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) {
v := m.updated_at
if v == nil {
return
@@ -8427,10 +8914,10 @@ func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) {
return *v, true
}
-// OldUpdatedAt returns the old "updated_at" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
@@ -8445,66 +8932,17 @@ func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err erro
}
// ResetUpdatedAt resets all changes to the "updated_at" field.
-func (m *GroupMutation) ResetUpdatedAt() {
+func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() {
m.updated_at = nil
}
-// SetDeletedAt sets the "deleted_at" field.
-func (m *GroupMutation) SetDeletedAt(t time.Time) {
- m.deleted_at = &t
-}
-
-// DeletedAt returns the value of the "deleted_at" field in the mutation.
-func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) {
- v := m.deleted_at
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldDeletedAt returns the old "deleted_at" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDeletedAt requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err)
- }
- return oldValue.DeletedAt, nil
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (m *GroupMutation) ClearDeletedAt() {
- m.deleted_at = nil
- m.clearedFields[group.FieldDeletedAt] = struct{}{}
-}
-
-// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation.
-func (m *GroupMutation) DeletedAtCleared() bool {
- _, ok := m.clearedFields[group.FieldDeletedAt]
- return ok
-}
-
-// ResetDeletedAt resets all changes to the "deleted_at" field.
-func (m *GroupMutation) ResetDeletedAt() {
- m.deleted_at = nil
- delete(m.clearedFields, group.FieldDeletedAt)
-}
-
// SetName sets the "name" field.
-func (m *GroupMutation) SetName(s string) {
+func (m *ErrorPassthroughRuleMutation) SetName(s string) {
m.name = &s
}
// Name returns the value of the "name" field in the mutation.
-func (m *GroupMutation) Name() (r string, exists bool) {
+func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) {
v := m.name
if v == nil {
return
@@ -8512,10 +8950,10 @@ func (m *GroupMutation) Name() (r string, exists bool) {
return *v, true
}
-// OldName returns the old "name" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldName returns the old "name" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) {
+func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldName is only allowed on UpdateOne operations")
}
@@ -8530,3285 +8968,3061 @@ func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) {
}
// ResetName resets all changes to the "name" field.
-func (m *GroupMutation) ResetName() {
+func (m *ErrorPassthroughRuleMutation) ResetName() {
m.name = nil
}
-// SetDescription sets the "description" field.
-func (m *GroupMutation) SetDescription(s string) {
- m.description = &s
+// SetEnabled sets the "enabled" field.
+func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) {
+ m.enabled = &b
}
-// Description returns the value of the "description" field in the mutation.
-func (m *GroupMutation) Description() (r string, exists bool) {
- v := m.description
+// Enabled returns the value of the "enabled" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) {
+ v := m.enabled
if v == nil {
return
}
return *v, true
}
-// OldDescription returns the old "description" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) {
+func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+ return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDescription requires an ID field in the mutation")
+ return v, errors.New("OldEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
}
- return oldValue.Description, nil
-}
-
-// ClearDescription clears the value of the "description" field.
-func (m *GroupMutation) ClearDescription() {
- m.description = nil
- m.clearedFields[group.FieldDescription] = struct{}{}
-}
-
-// DescriptionCleared returns if the "description" field was cleared in this mutation.
-func (m *GroupMutation) DescriptionCleared() bool {
- _, ok := m.clearedFields[group.FieldDescription]
- return ok
+ return oldValue.Enabled, nil
}
-// ResetDescription resets all changes to the "description" field.
-func (m *GroupMutation) ResetDescription() {
- m.description = nil
- delete(m.clearedFields, group.FieldDescription)
+// ResetEnabled resets all changes to the "enabled" field.
+func (m *ErrorPassthroughRuleMutation) ResetEnabled() {
+ m.enabled = nil
}
-// SetRateMultiplier sets the "rate_multiplier" field.
-func (m *GroupMutation) SetRateMultiplier(f float64) {
- m.rate_multiplier = &f
- m.addrate_multiplier = nil
+// SetPriority sets the "priority" field.
+func (m *ErrorPassthroughRuleMutation) SetPriority(i int) {
+ m.priority = &i
+ m.addpriority = nil
}
-// RateMultiplier returns the value of the "rate_multiplier" field in the mutation.
-func (m *GroupMutation) RateMultiplier() (r float64, exists bool) {
- v := m.rate_multiplier
+// Priority returns the value of the "priority" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) {
+ v := m.priority
if v == nil {
return
}
return *v, true
}
-// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations")
+ return v, errors.New("OldPriority is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRateMultiplier requires an ID field in the mutation")
+ return v, errors.New("OldPriority requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err)
+ return v, fmt.Errorf("querying old value for OldPriority: %w", err)
}
- return oldValue.RateMultiplier, nil
+ return oldValue.Priority, nil
}
-// AddRateMultiplier adds f to the "rate_multiplier" field.
-func (m *GroupMutation) AddRateMultiplier(f float64) {
- if m.addrate_multiplier != nil {
- *m.addrate_multiplier += f
+// AddPriority adds i to the "priority" field.
+func (m *ErrorPassthroughRuleMutation) AddPriority(i int) {
+ if m.addpriority != nil {
+ *m.addpriority += i
} else {
- m.addrate_multiplier = &f
+ m.addpriority = &i
}
}
-// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation.
-func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) {
- v := m.addrate_multiplier
+// AddedPriority returns the value that was added to the "priority" field in this mutation.
+func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) {
+ v := m.addpriority
if v == nil {
return
}
return *v, true
}
-// ResetRateMultiplier resets all changes to the "rate_multiplier" field.
-func (m *GroupMutation) ResetRateMultiplier() {
- m.rate_multiplier = nil
- m.addrate_multiplier = nil
+// ResetPriority resets all changes to the "priority" field.
+func (m *ErrorPassthroughRuleMutation) ResetPriority() {
+ m.priority = nil
+ m.addpriority = nil
}
-// SetIsExclusive sets the "is_exclusive" field.
-func (m *GroupMutation) SetIsExclusive(b bool) {
- m.is_exclusive = &b
+// SetErrorCodes sets the "error_codes" field.
+func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) {
+ m.error_codes = &i
+ m.appenderror_codes = nil
}
-// IsExclusive returns the value of the "is_exclusive" field in the mutation.
-func (m *GroupMutation) IsExclusive() (r bool, exists bool) {
- v := m.is_exclusive
+// ErrorCodes returns the value of the "error_codes" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) {
+ v := m.error_codes
if v == nil {
return
}
return *v, true
}
-// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) {
+func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations")
+ return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldIsExclusive requires an ID field in the mutation")
+ return v, errors.New("OldErrorCodes requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err)
+ return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err)
}
- return oldValue.IsExclusive, nil
+ return oldValue.ErrorCodes, nil
}
-// ResetIsExclusive resets all changes to the "is_exclusive" field.
-func (m *GroupMutation) ResetIsExclusive() {
- m.is_exclusive = nil
+// AppendErrorCodes adds i to the "error_codes" field.
+func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) {
+ m.appenderror_codes = append(m.appenderror_codes, i...)
}
-// SetStatus sets the "status" field.
-func (m *GroupMutation) SetStatus(s string) {
- m.status = &s
+// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation.
+func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) {
+ if len(m.appenderror_codes) == 0 {
+ return nil, false
+ }
+ return m.appenderror_codes, true
}
-// Status returns the value of the "status" field in the mutation.
-func (m *GroupMutation) Status() (r string, exists bool) {
- v := m.status
- if v == nil {
- return
- }
- return *v, true
+// ClearErrorCodes clears the value of the "error_codes" field.
+func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() {
+ m.error_codes = nil
+ m.appenderror_codes = nil
+ m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{}
}
-// OldStatus returns the old "status" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldStatus is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldStatus requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldStatus: %w", err)
- }
- return oldValue.Status, nil
+// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool {
+ _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes]
+ return ok
}
-// ResetStatus resets all changes to the "status" field.
-func (m *GroupMutation) ResetStatus() {
- m.status = nil
+// ResetErrorCodes resets all changes to the "error_codes" field.
+func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() {
+ m.error_codes = nil
+ m.appenderror_codes = nil
+ delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes)
}
-// SetPlatform sets the "platform" field.
-func (m *GroupMutation) SetPlatform(s string) {
- m.platform = &s
+// SetKeywords sets the "keywords" field.
+func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) {
+ m.keywords = &s
+ m.appendkeywords = nil
}
-// Platform returns the value of the "platform" field in the mutation.
-func (m *GroupMutation) Platform() (r string, exists bool) {
- v := m.platform
+// Keywords returns the value of the "keywords" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) {
+ v := m.keywords
if v == nil {
return
}
return *v, true
}
-// OldPlatform returns the old "platform" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) {
+func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPlatform is only allowed on UpdateOne operations")
+ return v, errors.New("OldKeywords is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPlatform requires an ID field in the mutation")
+ return v, errors.New("OldKeywords requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPlatform: %w", err)
+ return v, fmt.Errorf("querying old value for OldKeywords: %w", err)
}
- return oldValue.Platform, nil
+ return oldValue.Keywords, nil
}
-// ResetPlatform resets all changes to the "platform" field.
-func (m *GroupMutation) ResetPlatform() {
- m.platform = nil
+// AppendKeywords adds s to the "keywords" field.
+func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) {
+ m.appendkeywords = append(m.appendkeywords, s...)
}
-// SetSubscriptionType sets the "subscription_type" field.
-func (m *GroupMutation) SetSubscriptionType(s string) {
- m.subscription_type = &s
+// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation.
+func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) {
+ if len(m.appendkeywords) == 0 {
+ return nil, false
+ }
+ return m.appendkeywords, true
}
-// SubscriptionType returns the value of the "subscription_type" field in the mutation.
-func (m *GroupMutation) SubscriptionType() (r string, exists bool) {
- v := m.subscription_type
+// ClearKeywords clears the value of the "keywords" field.
+func (m *ErrorPassthroughRuleMutation) ClearKeywords() {
+ m.keywords = nil
+ m.appendkeywords = nil
+ m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{}
+}
+
+// KeywordsCleared returns if the "keywords" field was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool {
+ _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords]
+ return ok
+}
+
+// ResetKeywords resets all changes to the "keywords" field.
+func (m *ErrorPassthroughRuleMutation) ResetKeywords() {
+ m.keywords = nil
+ m.appendkeywords = nil
+ delete(m.clearedFields, errorpassthroughrule.FieldKeywords)
+}
+
+// SetMatchMode sets the "match_mode" field.
+func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) {
+ m.match_mode = &s
+}
+
+// MatchMode returns the value of the "match_mode" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) {
+ v := m.match_mode
if v == nil {
return
}
return *v, true
}
-// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) {
+func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations")
+ return v, errors.New("OldMatchMode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSubscriptionType requires an ID field in the mutation")
+ return v, errors.New("OldMatchMode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err)
+ return v, fmt.Errorf("querying old value for OldMatchMode: %w", err)
}
- return oldValue.SubscriptionType, nil
+ return oldValue.MatchMode, nil
}
-// ResetSubscriptionType resets all changes to the "subscription_type" field.
-func (m *GroupMutation) ResetSubscriptionType() {
- m.subscription_type = nil
+// ResetMatchMode resets all changes to the "match_mode" field.
+func (m *ErrorPassthroughRuleMutation) ResetMatchMode() {
+ m.match_mode = nil
}
-// SetDailyLimitUsd sets the "daily_limit_usd" field.
-func (m *GroupMutation) SetDailyLimitUsd(f float64) {
- m.daily_limit_usd = &f
- m.adddaily_limit_usd = nil
+// SetPlatforms sets the "platforms" field.
+func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) {
+ m.platforms = &s
+ m.appendplatforms = nil
}
-// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation.
-func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) {
- v := m.daily_limit_usd
+// Platforms returns the value of the "platforms" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) {
+ v := m.platforms
if v == nil {
return
}
return *v, true
}
-// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations")
+ return v, errors.New("OldPlatforms is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation")
+ return v, errors.New("OldPlatforms requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err)
+ return v, fmt.Errorf("querying old value for OldPlatforms: %w", err)
}
- return oldValue.DailyLimitUsd, nil
+ return oldValue.Platforms, nil
}
-// AddDailyLimitUsd adds f to the "daily_limit_usd" field.
-func (m *GroupMutation) AddDailyLimitUsd(f float64) {
- if m.adddaily_limit_usd != nil {
- *m.adddaily_limit_usd += f
- } else {
- m.adddaily_limit_usd = &f
- }
+// AppendPlatforms adds s to the "platforms" field.
+func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) {
+ m.appendplatforms = append(m.appendplatforms, s...)
}
-// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation.
-func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) {
- v := m.adddaily_limit_usd
- if v == nil {
- return
+// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation.
+func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) {
+ if len(m.appendplatforms) == 0 {
+ return nil, false
}
- return *v, true
+ return m.appendplatforms, true
}
-// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
-func (m *GroupMutation) ClearDailyLimitUsd() {
- m.daily_limit_usd = nil
- m.adddaily_limit_usd = nil
- m.clearedFields[group.FieldDailyLimitUsd] = struct{}{}
+// ClearPlatforms clears the value of the "platforms" field.
+func (m *ErrorPassthroughRuleMutation) ClearPlatforms() {
+ m.platforms = nil
+ m.appendplatforms = nil
+ m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{}
}
-// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation.
-func (m *GroupMutation) DailyLimitUsdCleared() bool {
- _, ok := m.clearedFields[group.FieldDailyLimitUsd]
+// PlatformsCleared returns if the "platforms" field was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool {
+ _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms]
return ok
}
-// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field.
-func (m *GroupMutation) ResetDailyLimitUsd() {
- m.daily_limit_usd = nil
- m.adddaily_limit_usd = nil
- delete(m.clearedFields, group.FieldDailyLimitUsd)
+// ResetPlatforms resets all changes to the "platforms" field.
+func (m *ErrorPassthroughRuleMutation) ResetPlatforms() {
+ m.platforms = nil
+ m.appendplatforms = nil
+ delete(m.clearedFields, errorpassthroughrule.FieldPlatforms)
}
-// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
-func (m *GroupMutation) SetWeeklyLimitUsd(f float64) {
- m.weekly_limit_usd = &f
- m.addweekly_limit_usd = nil
+// SetPassthroughCode sets the "passthrough_code" field.
+func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) {
+ m.passthrough_code = &b
}
-// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation.
-func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) {
- v := m.weekly_limit_usd
+// PassthroughCode returns the value of the "passthrough_code" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) {
+ v := m.passthrough_code
if v == nil {
return
}
return *v, true
}
-// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations")
+ return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation")
+ return v, errors.New("OldPassthroughCode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err)
- }
- return oldValue.WeeklyLimitUsd, nil
-}
-
-// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field.
-func (m *GroupMutation) AddWeeklyLimitUsd(f float64) {
- if m.addweekly_limit_usd != nil {
- *m.addweekly_limit_usd += f
- } else {
- m.addweekly_limit_usd = &f
- }
-}
-
-// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation.
-func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) {
- v := m.addweekly_limit_usd
- if v == nil {
- return
+ return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err)
}
- return *v, true
-}
-
-// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
-func (m *GroupMutation) ClearWeeklyLimitUsd() {
- m.weekly_limit_usd = nil
- m.addweekly_limit_usd = nil
- m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{}
-}
-
-// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation.
-func (m *GroupMutation) WeeklyLimitUsdCleared() bool {
- _, ok := m.clearedFields[group.FieldWeeklyLimitUsd]
- return ok
+ return oldValue.PassthroughCode, nil
}
-// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field.
-func (m *GroupMutation) ResetWeeklyLimitUsd() {
- m.weekly_limit_usd = nil
- m.addweekly_limit_usd = nil
- delete(m.clearedFields, group.FieldWeeklyLimitUsd)
+// ResetPassthroughCode resets all changes to the "passthrough_code" field.
+func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() {
+ m.passthrough_code = nil
}
-// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
-func (m *GroupMutation) SetMonthlyLimitUsd(f float64) {
- m.monthly_limit_usd = &f
- m.addmonthly_limit_usd = nil
+// SetResponseCode sets the "response_code" field.
+func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) {
+ m.response_code = &i
+ m.addresponse_code = nil
}
-// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation.
-func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) {
- v := m.monthly_limit_usd
+// ResponseCode returns the value of the "response_code" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) {
+ v := m.response_code
if v == nil {
return
}
return *v, true
}
-// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations")
+ return v, errors.New("OldResponseCode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation")
+ return v, errors.New("OldResponseCode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err)
+ return v, fmt.Errorf("querying old value for OldResponseCode: %w", err)
}
- return oldValue.MonthlyLimitUsd, nil
+ return oldValue.ResponseCode, nil
}
-// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field.
-func (m *GroupMutation) AddMonthlyLimitUsd(f float64) {
- if m.addmonthly_limit_usd != nil {
- *m.addmonthly_limit_usd += f
+// AddResponseCode adds i to the "response_code" field.
+func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) {
+ if m.addresponse_code != nil {
+ *m.addresponse_code += i
} else {
- m.addmonthly_limit_usd = &f
+ m.addresponse_code = &i
}
}
-// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation.
-func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) {
- v := m.addmonthly_limit_usd
+// AddedResponseCode returns the value that was added to the "response_code" field in this mutation.
+func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) {
+ v := m.addresponse_code
if v == nil {
return
}
return *v, true
}
-// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
-func (m *GroupMutation) ClearMonthlyLimitUsd() {
- m.monthly_limit_usd = nil
- m.addmonthly_limit_usd = nil
- m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{}
+// ClearResponseCode clears the value of the "response_code" field.
+func (m *ErrorPassthroughRuleMutation) ClearResponseCode() {
+ m.response_code = nil
+ m.addresponse_code = nil
+ m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{}
}
-// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation.
-func (m *GroupMutation) MonthlyLimitUsdCleared() bool {
- _, ok := m.clearedFields[group.FieldMonthlyLimitUsd]
+// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool {
+ _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode]
return ok
}
-// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field.
-func (m *GroupMutation) ResetMonthlyLimitUsd() {
- m.monthly_limit_usd = nil
- m.addmonthly_limit_usd = nil
- delete(m.clearedFields, group.FieldMonthlyLimitUsd)
+// ResetResponseCode resets all changes to the "response_code" field.
+func (m *ErrorPassthroughRuleMutation) ResetResponseCode() {
+ m.response_code = nil
+ m.addresponse_code = nil
+ delete(m.clearedFields, errorpassthroughrule.FieldResponseCode)
}
-// SetDefaultValidityDays sets the "default_validity_days" field.
-func (m *GroupMutation) SetDefaultValidityDays(i int) {
- m.default_validity_days = &i
- m.adddefault_validity_days = nil
+// SetPassthroughBody sets the "passthrough_body" field.
+func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) {
+ m.passthrough_body = &b
}
-// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation.
-func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) {
- v := m.default_validity_days
+// PassthroughBody returns the value of the "passthrough_body" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) {
+ v := m.passthrough_body
if v == nil {
return
}
return *v, true
}
-// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) {
+func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations")
+ return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation")
+ return v, errors.New("OldPassthroughBody requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err)
- }
- return oldValue.DefaultValidityDays, nil
-}
-
-// AddDefaultValidityDays adds i to the "default_validity_days" field.
-func (m *GroupMutation) AddDefaultValidityDays(i int) {
- if m.adddefault_validity_days != nil {
- *m.adddefault_validity_days += i
- } else {
- m.adddefault_validity_days = &i
- }
-}
-
-// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation.
-func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) {
- v := m.adddefault_validity_days
- if v == nil {
- return
+ return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err)
}
- return *v, true
+ return oldValue.PassthroughBody, nil
}
-// ResetDefaultValidityDays resets all changes to the "default_validity_days" field.
-func (m *GroupMutation) ResetDefaultValidityDays() {
- m.default_validity_days = nil
- m.adddefault_validity_days = nil
+// ResetPassthroughBody resets all changes to the "passthrough_body" field.
+func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() {
+ m.passthrough_body = nil
}
-// SetImagePrice1k sets the "image_price_1k" field.
-func (m *GroupMutation) SetImagePrice1k(f float64) {
- m.image_price_1k = &f
- m.addimage_price_1k = nil
+// SetCustomMessage sets the "custom_message" field.
+func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) {
+ m.custom_message = &s
}
-// ImagePrice1k returns the value of the "image_price_1k" field in the mutation.
-func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) {
- v := m.image_price_1k
+// CustomMessage returns the value of the "custom_message" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) {
+ v := m.custom_message
if v == nil {
return
}
return *v, true
}
-// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations")
+ return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldImagePrice1k requires an ID field in the mutation")
+ return v, errors.New("OldCustomMessage requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err)
+ return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err)
}
- return oldValue.ImagePrice1k, nil
+ return oldValue.CustomMessage, nil
}
-// AddImagePrice1k adds f to the "image_price_1k" field.
-func (m *GroupMutation) AddImagePrice1k(f float64) {
- if m.addimage_price_1k != nil {
- *m.addimage_price_1k += f
- } else {
- m.addimage_price_1k = &f
- }
+// ClearCustomMessage clears the value of the "custom_message" field.
+func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() {
+ m.custom_message = nil
+ m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{}
}
-// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation.
-func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) {
- v := m.addimage_price_1k
- if v == nil {
- return
- }
- return *v, true
+// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool {
+ _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage]
+ return ok
}
-// ClearImagePrice1k clears the value of the "image_price_1k" field.
-func (m *GroupMutation) ClearImagePrice1k() {
- m.image_price_1k = nil
- m.addimage_price_1k = nil
- m.clearedFields[group.FieldImagePrice1k] = struct{}{}
+// ResetCustomMessage resets all changes to the "custom_message" field.
+func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() {
+ m.custom_message = nil
+ delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage)
}
-// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation.
-func (m *GroupMutation) ImagePrice1kCleared() bool {
- _, ok := m.clearedFields[group.FieldImagePrice1k]
- return ok
+// SetSkipMonitoring sets the "skip_monitoring" field.
+func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) {
+ m.skip_monitoring = &b
}
-// ResetImagePrice1k resets all changes to the "image_price_1k" field.
-func (m *GroupMutation) ResetImagePrice1k() {
- m.image_price_1k = nil
- m.addimage_price_1k = nil
- delete(m.clearedFields, group.FieldImagePrice1k)
-}
-
-// SetImagePrice2k sets the "image_price_2k" field.
-func (m *GroupMutation) SetImagePrice2k(f float64) {
- m.image_price_2k = &f
- m.addimage_price_2k = nil
-}
-
-// ImagePrice2k returns the value of the "image_price_2k" field in the mutation.
-func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) {
- v := m.image_price_2k
+// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) {
+ v := m.skip_monitoring
if v == nil {
return
}
return *v, true
}
-// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations")
+ return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldImagePrice2k requires an ID field in the mutation")
+ return v, errors.New("OldSkipMonitoring requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err)
- }
- return oldValue.ImagePrice2k, nil
-}
-
-// AddImagePrice2k adds f to the "image_price_2k" field.
-func (m *GroupMutation) AddImagePrice2k(f float64) {
- if m.addimage_price_2k != nil {
- *m.addimage_price_2k += f
- } else {
- m.addimage_price_2k = &f
- }
-}
-
-// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation.
-func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) {
- v := m.addimage_price_2k
- if v == nil {
- return
+ return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err)
}
- return *v, true
-}
-
-// ClearImagePrice2k clears the value of the "image_price_2k" field.
-func (m *GroupMutation) ClearImagePrice2k() {
- m.image_price_2k = nil
- m.addimage_price_2k = nil
- m.clearedFields[group.FieldImagePrice2k] = struct{}{}
-}
-
-// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation.
-func (m *GroupMutation) ImagePrice2kCleared() bool {
- _, ok := m.clearedFields[group.FieldImagePrice2k]
- return ok
+ return oldValue.SkipMonitoring, nil
}
-// ResetImagePrice2k resets all changes to the "image_price_2k" field.
-func (m *GroupMutation) ResetImagePrice2k() {
- m.image_price_2k = nil
- m.addimage_price_2k = nil
- delete(m.clearedFields, group.FieldImagePrice2k)
+// ResetSkipMonitoring resets all changes to the "skip_monitoring" field.
+func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() {
+ m.skip_monitoring = nil
}
-// SetImagePrice4k sets the "image_price_4k" field.
-func (m *GroupMutation) SetImagePrice4k(f float64) {
- m.image_price_4k = &f
- m.addimage_price_4k = nil
+// SetDescription sets the "description" field.
+func (m *ErrorPassthroughRuleMutation) SetDescription(s string) {
+ m.description = &s
}
-// ImagePrice4k returns the value of the "image_price_4k" field in the mutation.
-func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) {
- v := m.image_price_4k
+// Description returns the value of the "description" field in the mutation.
+func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) {
+ v := m.description
if v == nil {
return
}
return *v, true
}
-// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity.
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) {
+func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations")
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldImagePrice4k requires an ID field in the mutation")
+ return v, errors.New("OldDescription requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err)
- }
- return oldValue.ImagePrice4k, nil
-}
-
-// AddImagePrice4k adds f to the "image_price_4k" field.
-func (m *GroupMutation) AddImagePrice4k(f float64) {
- if m.addimage_price_4k != nil {
- *m.addimage_price_4k += f
- } else {
- m.addimage_price_4k = &f
- }
-}
-
-// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation.
-func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) {
- v := m.addimage_price_4k
- if v == nil {
- return
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
}
- return *v, true
+ return oldValue.Description, nil
}
-// ClearImagePrice4k clears the value of the "image_price_4k" field.
-func (m *GroupMutation) ClearImagePrice4k() {
- m.image_price_4k = nil
- m.addimage_price_4k = nil
- m.clearedFields[group.FieldImagePrice4k] = struct{}{}
+// ClearDescription clears the value of the "description" field.
+func (m *ErrorPassthroughRuleMutation) ClearDescription() {
+ m.description = nil
+ m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{}
}
-// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation.
-func (m *GroupMutation) ImagePrice4kCleared() bool {
- _, ok := m.clearedFields[group.FieldImagePrice4k]
+// DescriptionCleared returns if the "description" field was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool {
+ _, ok := m.clearedFields[errorpassthroughrule.FieldDescription]
return ok
}
-// ResetImagePrice4k resets all changes to the "image_price_4k" field.
-func (m *GroupMutation) ResetImagePrice4k() {
- m.image_price_4k = nil
- m.addimage_price_4k = nil
- delete(m.clearedFields, group.FieldImagePrice4k)
+// ResetDescription resets all changes to the "description" field.
+func (m *ErrorPassthroughRuleMutation) ResetDescription() {
+ m.description = nil
+ delete(m.clearedFields, errorpassthroughrule.FieldDescription)
}
-// SetClaudeCodeOnly sets the "claude_code_only" field.
-func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
- m.claude_code_only = &b
+// Where appends a list predicates to the ErrorPassthroughRuleMutation builder.
+func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) {
+ m.predicates = append(m.predicates, ps...)
}
-// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
-func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
- v := m.claude_code_only
- if v == nil {
- return
+// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ErrorPassthroughRule, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
}
- return *v, true
+ m.Where(p...)
}
-// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
- }
- return oldValue.ClaudeCodeOnly, nil
+// Op returns the operation name.
+func (m *ErrorPassthroughRuleMutation) Op() Op {
+ return m.op
}
-// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
-func (m *GroupMutation) ResetClaudeCodeOnly() {
- m.claude_code_only = nil
+// SetOp allows setting the mutation operation.
+func (m *ErrorPassthroughRuleMutation) SetOp(op Op) {
+ m.op = op
}
-// SetFallbackGroupID sets the "fallback_group_id" field.
-func (m *GroupMutation) SetFallbackGroupID(i int64) {
- m.fallback_group_id = &i
- m.addfallback_group_id = nil
+// Type returns the node type of this mutation (ErrorPassthroughRule).
+func (m *ErrorPassthroughRuleMutation) Type() string {
+ return m.typ
}
-// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
-func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
- v := m.fallback_group_id
- if v == nil {
- return
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ErrorPassthroughRuleMutation) Fields() []string {
+ fields := make([]string, 0, 15)
+ if m.created_at != nil {
+ fields = append(fields, errorpassthroughrule.FieldCreatedAt)
}
- return *v, true
-}
-
-// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
+ if m.updated_at != nil {
+ fields = append(fields, errorpassthroughrule.FieldUpdatedAt)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
+ if m.name != nil {
+ fields = append(fields, errorpassthroughrule.FieldName)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
+ if m.enabled != nil {
+ fields = append(fields, errorpassthroughrule.FieldEnabled)
}
- return oldValue.FallbackGroupID, nil
-}
-
-// AddFallbackGroupID adds i to the "fallback_group_id" field.
-func (m *GroupMutation) AddFallbackGroupID(i int64) {
- if m.addfallback_group_id != nil {
- *m.addfallback_group_id += i
- } else {
- m.addfallback_group_id = &i
+ if m.priority != nil {
+ fields = append(fields, errorpassthroughrule.FieldPriority)
}
-}
-
-// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
-func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
- v := m.addfallback_group_id
- if v == nil {
- return
+ if m.error_codes != nil {
+ fields = append(fields, errorpassthroughrule.FieldErrorCodes)
}
- return *v, true
-}
-
-// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
-func (m *GroupMutation) ClearFallbackGroupID() {
- m.fallback_group_id = nil
- m.addfallback_group_id = nil
- m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
-}
-
-// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
-func (m *GroupMutation) FallbackGroupIDCleared() bool {
- _, ok := m.clearedFields[group.FieldFallbackGroupID]
- return ok
-}
-
-// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
-func (m *GroupMutation) ResetFallbackGroupID() {
- m.fallback_group_id = nil
- m.addfallback_group_id = nil
- delete(m.clearedFields, group.FieldFallbackGroupID)
-}
-
-// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
-func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) {
- m.fallback_group_id_on_invalid_request = &i
- m.addfallback_group_id_on_invalid_request = nil
-}
-
-// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation.
-func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
- v := m.fallback_group_id_on_invalid_request
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation")
+ if m.keywords != nil {
+ fields = append(fields, errorpassthroughrule.FieldKeywords)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err)
+ if m.match_mode != nil {
+ fields = append(fields, errorpassthroughrule.FieldMatchMode)
}
- return oldValue.FallbackGroupIDOnInvalidRequest, nil
-}
-
-// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field.
-func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) {
- if m.addfallback_group_id_on_invalid_request != nil {
- *m.addfallback_group_id_on_invalid_request += i
- } else {
- m.addfallback_group_id_on_invalid_request = &i
+ if m.platforms != nil {
+ fields = append(fields, errorpassthroughrule.FieldPlatforms)
}
-}
-
-// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation.
-func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
- v := m.addfallback_group_id_on_invalid_request
- if v == nil {
- return
+ if m.passthrough_code != nil {
+ fields = append(fields, errorpassthroughrule.FieldPassthroughCode)
}
- return *v, true
-}
-
-// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
-func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() {
- m.fallback_group_id_on_invalid_request = nil
- m.addfallback_group_id_on_invalid_request = nil
- m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{}
-}
-
-// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation.
-func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool {
- _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest]
- return ok
-}
-
-// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field.
-func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() {
- m.fallback_group_id_on_invalid_request = nil
- m.addfallback_group_id_on_invalid_request = nil
- delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest)
-}
-
-// SetModelRouting sets the "model_routing" field.
-func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
- m.model_routing = &value
-}
-
-// ModelRouting returns the value of the "model_routing" field in the mutation.
-func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) {
- v := m.model_routing
- if v == nil {
- return
+ if m.response_code != nil {
+ fields = append(fields, errorpassthroughrule.FieldResponseCode)
}
- return *v, true
-}
-
-// OldModelRouting returns the old "model_routing" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldModelRouting is only allowed on UpdateOne operations")
+ if m.passthrough_body != nil {
+ fields = append(fields, errorpassthroughrule.FieldPassthroughBody)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldModelRouting requires an ID field in the mutation")
+ if m.custom_message != nil {
+ fields = append(fields, errorpassthroughrule.FieldCustomMessage)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldModelRouting: %w", err)
+ if m.skip_monitoring != nil {
+ fields = append(fields, errorpassthroughrule.FieldSkipMonitoring)
}
- return oldValue.ModelRouting, nil
-}
-
-// ClearModelRouting clears the value of the "model_routing" field.
-func (m *GroupMutation) ClearModelRouting() {
- m.model_routing = nil
- m.clearedFields[group.FieldModelRouting] = struct{}{}
-}
-
-// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation.
-func (m *GroupMutation) ModelRoutingCleared() bool {
- _, ok := m.clearedFields[group.FieldModelRouting]
- return ok
-}
-
-// ResetModelRouting resets all changes to the "model_routing" field.
-func (m *GroupMutation) ResetModelRouting() {
- m.model_routing = nil
- delete(m.clearedFields, group.FieldModelRouting)
-}
-
-// SetModelRoutingEnabled sets the "model_routing_enabled" field.
-func (m *GroupMutation) SetModelRoutingEnabled(b bool) {
- m.model_routing_enabled = &b
-}
-
-// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation.
-func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) {
- v := m.model_routing_enabled
- if v == nil {
- return
+ if m.description != nil {
+ fields = append(fields, errorpassthroughrule.FieldDescription)
}
- return *v, true
+ return fields
}
-// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err)
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case errorpassthroughrule.FieldCreatedAt:
+ return m.CreatedAt()
+ case errorpassthroughrule.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case errorpassthroughrule.FieldName:
+ return m.Name()
+ case errorpassthroughrule.FieldEnabled:
+ return m.Enabled()
+ case errorpassthroughrule.FieldPriority:
+ return m.Priority()
+ case errorpassthroughrule.FieldErrorCodes:
+ return m.ErrorCodes()
+ case errorpassthroughrule.FieldKeywords:
+ return m.Keywords()
+ case errorpassthroughrule.FieldMatchMode:
+ return m.MatchMode()
+ case errorpassthroughrule.FieldPlatforms:
+ return m.Platforms()
+ case errorpassthroughrule.FieldPassthroughCode:
+ return m.PassthroughCode()
+ case errorpassthroughrule.FieldResponseCode:
+ return m.ResponseCode()
+ case errorpassthroughrule.FieldPassthroughBody:
+ return m.PassthroughBody()
+ case errorpassthroughrule.FieldCustomMessage:
+ return m.CustomMessage()
+ case errorpassthroughrule.FieldSkipMonitoring:
+ return m.SkipMonitoring()
+ case errorpassthroughrule.FieldDescription:
+ return m.Description()
}
- return oldValue.ModelRoutingEnabled, nil
-}
-
-// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field.
-func (m *GroupMutation) ResetModelRoutingEnabled() {
- m.model_routing_enabled = nil
-}
-
-// SetMcpXMLInject sets the "mcp_xml_inject" field.
-func (m *GroupMutation) SetMcpXMLInject(b bool) {
- m.mcp_xml_inject = &b
+ return nil, false
}
-// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation.
-func (m *GroupMutation) McpXMLInject() (r bool, exists bool) {
- v := m.mcp_xml_inject
- if v == nil {
- return
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case errorpassthroughrule.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case errorpassthroughrule.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case errorpassthroughrule.FieldName:
+ return m.OldName(ctx)
+ case errorpassthroughrule.FieldEnabled:
+ return m.OldEnabled(ctx)
+ case errorpassthroughrule.FieldPriority:
+ return m.OldPriority(ctx)
+ case errorpassthroughrule.FieldErrorCodes:
+ return m.OldErrorCodes(ctx)
+ case errorpassthroughrule.FieldKeywords:
+ return m.OldKeywords(ctx)
+ case errorpassthroughrule.FieldMatchMode:
+ return m.OldMatchMode(ctx)
+ case errorpassthroughrule.FieldPlatforms:
+ return m.OldPlatforms(ctx)
+ case errorpassthroughrule.FieldPassthroughCode:
+ return m.OldPassthroughCode(ctx)
+ case errorpassthroughrule.FieldResponseCode:
+ return m.OldResponseCode(ctx)
+ case errorpassthroughrule.FieldPassthroughBody:
+ return m.OldPassthroughBody(ctx)
+ case errorpassthroughrule.FieldCustomMessage:
+ return m.OldCustomMessage(ctx)
+ case errorpassthroughrule.FieldSkipMonitoring:
+ return m.OldSkipMonitoring(ctx)
+ case errorpassthroughrule.FieldDescription:
+ return m.OldDescription(ctx)
}
- return *v, true
+ return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name)
}
-// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldMcpXMLInject requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err)
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case errorpassthroughrule.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case errorpassthroughrule.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case errorpassthroughrule.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case errorpassthroughrule.FieldEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEnabled(v)
+ return nil
+ case errorpassthroughrule.FieldPriority:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPriority(v)
+ return nil
+ case errorpassthroughrule.FieldErrorCodes:
+ v, ok := value.([]int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetErrorCodes(v)
+ return nil
+ case errorpassthroughrule.FieldKeywords:
+ v, ok := value.([]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetKeywords(v)
+ return nil
+ case errorpassthroughrule.FieldMatchMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMatchMode(v)
+ return nil
+ case errorpassthroughrule.FieldPlatforms:
+ v, ok := value.([]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPlatforms(v)
+ return nil
+ case errorpassthroughrule.FieldPassthroughCode:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPassthroughCode(v)
+ return nil
+ case errorpassthroughrule.FieldResponseCode:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResponseCode(v)
+ return nil
+ case errorpassthroughrule.FieldPassthroughBody:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPassthroughBody(v)
+ return nil
+ case errorpassthroughrule.FieldCustomMessage:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCustomMessage(v)
+ return nil
+ case errorpassthroughrule.FieldSkipMonitoring:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSkipMonitoring(v)
+ return nil
+ case errorpassthroughrule.FieldDescription:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDescription(v)
+ return nil
}
- return oldValue.McpXMLInject, nil
+ return fmt.Errorf("unknown ErrorPassthroughRule field %s", name)
}
-// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field.
-func (m *GroupMutation) ResetMcpXMLInject() {
- m.mcp_xml_inject = nil
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ErrorPassthroughRuleMutation) AddedFields() []string {
+ var fields []string
+ if m.addpriority != nil {
+ fields = append(fields, errorpassthroughrule.FieldPriority)
+ }
+ if m.addresponse_code != nil {
+ fields = append(fields, errorpassthroughrule.FieldResponseCode)
+ }
+ return fields
}
-// SetSupportedModelScopes sets the "supported_model_scopes" field.
-func (m *GroupMutation) SetSupportedModelScopes(s []string) {
- m.supported_model_scopes = &s
- m.appendsupported_model_scopes = nil
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case errorpassthroughrule.FieldPriority:
+ return m.AddedPriority()
+ case errorpassthroughrule.FieldResponseCode:
+ return m.AddedResponseCode()
+ }
+ return nil, false
}
-// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation.
-func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) {
- v := m.supported_model_scopes
- if v == nil {
- return
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case errorpassthroughrule.FieldPriority:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddPriority(v)
+ return nil
+ case errorpassthroughrule.FieldResponseCode:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddResponseCode(v)
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name)
}
-// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations")
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ErrorPassthroughRuleMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) {
+ fields = append(fields, errorpassthroughrule.FieldErrorCodes)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation")
+ if m.FieldCleared(errorpassthroughrule.FieldKeywords) {
+ fields = append(fields, errorpassthroughrule.FieldKeywords)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err)
+ if m.FieldCleared(errorpassthroughrule.FieldPlatforms) {
+ fields = append(fields, errorpassthroughrule.FieldPlatforms)
}
- return oldValue.SupportedModelScopes, nil
-}
-
-// AppendSupportedModelScopes adds s to the "supported_model_scopes" field.
-func (m *GroupMutation) AppendSupportedModelScopes(s []string) {
- m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...)
-}
-
-// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation.
-func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) {
- if len(m.appendsupported_model_scopes) == 0 {
- return nil, false
+ if m.FieldCleared(errorpassthroughrule.FieldResponseCode) {
+ fields = append(fields, errorpassthroughrule.FieldResponseCode)
}
- return m.appendsupported_model_scopes, true
+ if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) {
+ fields = append(fields, errorpassthroughrule.FieldCustomMessage)
+ }
+ if m.FieldCleared(errorpassthroughrule.FieldDescription) {
+ fields = append(fields, errorpassthroughrule.FieldDescription)
+ }
+ return fields
}
-// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field.
-func (m *GroupMutation) ResetSupportedModelScopes() {
- m.supported_model_scopes = nil
- m.appendsupported_model_scopes = nil
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
}
-// SetSortOrder sets the "sort_order" field.
-func (m *GroupMutation) SetSortOrder(i int) {
- m.sort_order = &i
- m.addsort_order = nil
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ErrorPassthroughRuleMutation) ClearField(name string) error {
+ switch name {
+ case errorpassthroughrule.FieldErrorCodes:
+ m.ClearErrorCodes()
+ return nil
+ case errorpassthroughrule.FieldKeywords:
+ m.ClearKeywords()
+ return nil
+ case errorpassthroughrule.FieldPlatforms:
+ m.ClearPlatforms()
+ return nil
+ case errorpassthroughrule.FieldResponseCode:
+ m.ClearResponseCode()
+ return nil
+ case errorpassthroughrule.FieldCustomMessage:
+ m.ClearCustomMessage()
+ return nil
+ case errorpassthroughrule.FieldDescription:
+ m.ClearDescription()
+ return nil
+ }
+ return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name)
}
-// SortOrder returns the value of the "sort_order" field in the mutation.
-func (m *GroupMutation) SortOrder() (r int, exists bool) {
- v := m.sort_order
- if v == nil {
- return
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ErrorPassthroughRuleMutation) ResetField(name string) error {
+ switch name {
+ case errorpassthroughrule.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case errorpassthroughrule.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case errorpassthroughrule.FieldName:
+ m.ResetName()
+ return nil
+ case errorpassthroughrule.FieldEnabled:
+ m.ResetEnabled()
+ return nil
+ case errorpassthroughrule.FieldPriority:
+ m.ResetPriority()
+ return nil
+ case errorpassthroughrule.FieldErrorCodes:
+ m.ResetErrorCodes()
+ return nil
+ case errorpassthroughrule.FieldKeywords:
+ m.ResetKeywords()
+ return nil
+ case errorpassthroughrule.FieldMatchMode:
+ m.ResetMatchMode()
+ return nil
+ case errorpassthroughrule.FieldPlatforms:
+ m.ResetPlatforms()
+ return nil
+ case errorpassthroughrule.FieldPassthroughCode:
+ m.ResetPassthroughCode()
+ return nil
+ case errorpassthroughrule.FieldResponseCode:
+ m.ResetResponseCode()
+ return nil
+ case errorpassthroughrule.FieldPassthroughBody:
+ m.ResetPassthroughBody()
+ return nil
+ case errorpassthroughrule.FieldCustomMessage:
+ m.ResetCustomMessage()
+ return nil
+ case errorpassthroughrule.FieldSkipMonitoring:
+ m.ResetSkipMonitoring()
+ return nil
+ case errorpassthroughrule.FieldDescription:
+ m.ResetDescription()
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown ErrorPassthroughRule field %s", name)
}
-// OldSortOrder returns the old "sort_order" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSortOrder requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
- }
- return oldValue.SortOrder, nil
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ErrorPassthroughRuleMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
}
-// AddSortOrder adds i to the "sort_order" field.
-func (m *GroupMutation) AddSortOrder(i int) {
- if m.addsort_order != nil {
- *m.addsort_order += i
- } else {
- m.addsort_order = &i
- }
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value {
+ return nil
}
-// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
-func (m *GroupMutation) AddedSortOrder() (r int, exists bool) {
- v := m.addsort_order
- if v == nil {
- return
- }
- return *v, true
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
}
-// ResetSortOrder resets all changes to the "sort_order" field.
-func (m *GroupMutation) ResetSortOrder() {
- m.sort_order = nil
- m.addsort_order = nil
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value {
+ return nil
}
-// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
-func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
- m.allow_messages_dispatch = &b
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
}
-// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
-func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
- v := m.allow_messages_dispatch
- if v == nil {
- return
- }
- return *v, true
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool {
+ return false
}
-// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
- }
- return oldValue.AllowMessagesDispatch, nil
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name)
}
-// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
-func (m *GroupMutation) ResetAllowMessagesDispatch() {
- m.allow_messages_dispatch = nil
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name)
}
-// SetRequireOauthOnly sets the "require_oauth_only" field.
-func (m *GroupMutation) SetRequireOauthOnly(b bool) {
- m.require_oauth_only = &b
+// GroupMutation represents an operation that mutates the Group nodes in the graph.
+type GroupMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ deleted_at *time.Time
+ name *string
+ description *string
+ rate_multiplier *float64
+ addrate_multiplier *float64
+ is_exclusive *bool
+ status *string
+ platform *string
+ subscription_type *string
+ daily_limit_usd *float64
+ adddaily_limit_usd *float64
+ weekly_limit_usd *float64
+ addweekly_limit_usd *float64
+ monthly_limit_usd *float64
+ addmonthly_limit_usd *float64
+ default_validity_days *int
+ adddefault_validity_days *int
+ image_price_1k *float64
+ addimage_price_1k *float64
+ image_price_2k *float64
+ addimage_price_2k *float64
+ image_price_4k *float64
+ addimage_price_4k *float64
+ claude_code_only *bool
+ fallback_group_id *int64
+ addfallback_group_id *int64
+ fallback_group_id_on_invalid_request *int64
+ addfallback_group_id_on_invalid_request *int64
+ model_routing *map[string][]int64
+ model_routing_enabled *bool
+ mcp_xml_inject *bool
+ supported_model_scopes *[]string
+ appendsupported_model_scopes []string
+ sort_order *int
+ addsort_order *int
+ allow_messages_dispatch *bool
+ require_oauth_only *bool
+ require_privacy_set *bool
+ default_mapped_model *string
+ messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
+ clearedFields map[string]struct{}
+ api_keys map[int64]struct{}
+ removedapi_keys map[int64]struct{}
+ clearedapi_keys bool
+ redeem_codes map[int64]struct{}
+ removedredeem_codes map[int64]struct{}
+ clearedredeem_codes bool
+ subscriptions map[int64]struct{}
+ removedsubscriptions map[int64]struct{}
+ clearedsubscriptions bool
+ usage_logs map[int64]struct{}
+ removedusage_logs map[int64]struct{}
+ clearedusage_logs bool
+ accounts map[int64]struct{}
+ removedaccounts map[int64]struct{}
+ clearedaccounts bool
+ allowed_users map[int64]struct{}
+ removedallowed_users map[int64]struct{}
+ clearedallowed_users bool
+ done bool
+ oldValue func(context.Context) (*Group, error)
+ predicates []predicate.Group
}
-// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation.
-func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) {
- v := m.require_oauth_only
- if v == nil {
- return
- }
- return *v, true
-}
+var _ ent.Mutation = (*GroupMutation)(nil)
-// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation")
+// groupOption allows management of the mutation configuration using functional options.
+type groupOption func(*GroupMutation)
+
+// newGroupMutation creates new mutation for the Group entity.
+func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation {
+ m := &GroupMutation{
+ config: c,
+ op: op,
+ typ: TypeGroup,
+ clearedFields: make(map[string]struct{}),
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err)
+ for _, opt := range opts {
+ opt(m)
}
- return oldValue.RequireOauthOnly, nil
+ return m
}
-// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field.
-func (m *GroupMutation) ResetRequireOauthOnly() {
- m.require_oauth_only = nil
+// withGroupID sets the ID field of the mutation.
+func withGroupID(id int64) groupOption {
+ return func(m *GroupMutation) {
+ var (
+ err error
+ once sync.Once
+ value *Group
+ )
+ m.oldValue = func(ctx context.Context) (*Group, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().Group.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
}
-// SetRequirePrivacySet sets the "require_privacy_set" field.
-func (m *GroupMutation) SetRequirePrivacySet(b bool) {
- m.require_privacy_set = &b
+// withGroup sets the old Group of the mutation.
+func withGroup(node *Group) groupOption {
+ return func(m *GroupMutation) {
+ m.oldValue = func(context.Context) (*Group, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
}
-// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation.
-func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) {
- v := m.require_privacy_set
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m GroupMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m GroupMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *GroupMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().Group.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *GroupMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
if v == nil {
return
}
return *v, true
}
-// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity.
+// OldCreatedAt returns the old "created_at" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) {
+func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations")
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation")
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err)
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
- return oldValue.RequirePrivacySet, nil
+ return oldValue.CreatedAt, nil
}
-// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field.
-func (m *GroupMutation) ResetRequirePrivacySet() {
- m.require_privacy_set = nil
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *GroupMutation) ResetCreatedAt() {
+ m.created_at = nil
}
-// SetDefaultMappedModel sets the "default_mapped_model" field.
-func (m *GroupMutation) SetDefaultMappedModel(s string) {
- m.default_mapped_model = &s
+// SetUpdatedAt sets the "updated_at" field.
+func (m *GroupMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
}
-// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
-func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
- v := m.default_mapped_model
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
if v == nil {
return
}
return *v, true
}
-// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
+// OldUpdatedAt returns the old "updated_at" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
+func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
}
- return oldValue.DefaultMappedModel, nil
+ return oldValue.UpdatedAt, nil
}
-// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
-func (m *GroupMutation) ResetDefaultMappedModel() {
- m.default_mapped_model = nil
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *GroupMutation) ResetUpdatedAt() {
+ m.updated_at = nil
}
-// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
-func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) {
- m.messages_dispatch_model_config = &damdmc
+// SetDeletedAt sets the "deleted_at" field.
+func (m *GroupMutation) SetDeletedAt(t time.Time) {
+ m.deleted_at = &t
}
-// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation.
-func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) {
- v := m.messages_dispatch_model_config
+// DeletedAt returns the value of the "deleted_at" field in the mutation.
+func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) {
+ v := m.deleted_at
if v == nil {
return
}
return *v, true
}
-// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity.
+// OldDeletedAt returns the old "deleted_at" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) {
+func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations")
+ return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation")
+ return v, errors.New("OldDeletedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err)
+ return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err)
}
- return oldValue.MessagesDispatchModelConfig, nil
+ return oldValue.DeletedAt, nil
}
-// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field.
-func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
- m.messages_dispatch_model_config = nil
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (m *GroupMutation) ClearDeletedAt() {
+ m.deleted_at = nil
+ m.clearedFields[group.FieldDeletedAt] = struct{}{}
}
-// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
-func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
- if m.api_keys == nil {
- m.api_keys = make(map[int64]struct{})
- }
- for i := range ids {
- m.api_keys[ids[i]] = struct{}{}
- }
+// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation.
+func (m *GroupMutation) DeletedAtCleared() bool {
+ _, ok := m.clearedFields[group.FieldDeletedAt]
+ return ok
}
-// ClearAPIKeys clears the "api_keys" edge to the APIKey entity.
-func (m *GroupMutation) ClearAPIKeys() {
- m.clearedapi_keys = true
+// ResetDeletedAt resets all changes to the "deleted_at" field.
+func (m *GroupMutation) ResetDeletedAt() {
+ m.deleted_at = nil
+ delete(m.clearedFields, group.FieldDeletedAt)
}
-// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared.
-func (m *GroupMutation) APIKeysCleared() bool {
- return m.clearedapi_keys
+// SetName sets the "name" field.
+func (m *GroupMutation) SetName(s string) {
+ m.name = &s
}
-// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs.
-func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) {
- if m.removedapi_keys == nil {
- m.removedapi_keys = make(map[int64]struct{})
- }
- for i := range ids {
- delete(m.api_keys, ids[i])
- m.removedapi_keys[ids[i]] = struct{}{}
+// Name returns the value of the "name" field in the mutation.
+func (m *GroupMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
}
+ return *v, true
}
-// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity.
-func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) {
- for id := range m.removedapi_keys {
- ids = append(ids, id)
+// OldName returns the old "name" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
}
- return
-}
-
-// APIKeysIDs returns the "api_keys" edge IDs in the mutation.
-func (m *GroupMutation) APIKeysIDs() (ids []int64) {
- for id := range m.api_keys {
- ids = append(ids, id)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
}
- return
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
}
-// ResetAPIKeys resets all changes to the "api_keys" edge.
-func (m *GroupMutation) ResetAPIKeys() {
- m.api_keys = nil
- m.clearedapi_keys = false
- m.removedapi_keys = nil
+// ResetName resets all changes to the "name" field.
+func (m *GroupMutation) ResetName() {
+ m.name = nil
}
-// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids.
-func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) {
- if m.redeem_codes == nil {
- m.redeem_codes = make(map[int64]struct{})
- }
- for i := range ids {
- m.redeem_codes[ids[i]] = struct{}{}
- }
+// SetDescription sets the "description" field.
+func (m *GroupMutation) SetDescription(s string) {
+ m.description = &s
}
-// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity.
-func (m *GroupMutation) ClearRedeemCodes() {
- m.clearedredeem_codes = true
-}
-
-// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared.
-func (m *GroupMutation) RedeemCodesCleared() bool {
- return m.clearedredeem_codes
+// Description returns the value of the "description" field in the mutation.
+func (m *GroupMutation) Description() (r string, exists bool) {
+ v := m.description
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs.
-func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) {
- if m.removedredeem_codes == nil {
- m.removedredeem_codes = make(map[int64]struct{})
+// OldDescription returns the old "description" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
}
- for i := range ids {
- delete(m.redeem_codes, ids[i])
- m.removedredeem_codes[ids[i]] = struct{}{}
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDescription requires an ID field in the mutation")
}
-}
-
-// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity.
-func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) {
- for id := range m.removedredeem_codes {
- ids = append(ids, id)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
}
- return
+ return oldValue.Description, nil
}
-// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation.
-func (m *GroupMutation) RedeemCodesIDs() (ids []int64) {
- for id := range m.redeem_codes {
- ids = append(ids, id)
- }
- return
+// ClearDescription clears the value of the "description" field.
+func (m *GroupMutation) ClearDescription() {
+ m.description = nil
+ m.clearedFields[group.FieldDescription] = struct{}{}
}
-// ResetRedeemCodes resets all changes to the "redeem_codes" edge.
-func (m *GroupMutation) ResetRedeemCodes() {
- m.redeem_codes = nil
- m.clearedredeem_codes = false
- m.removedredeem_codes = nil
+// DescriptionCleared returns if the "description" field was cleared in this mutation.
+func (m *GroupMutation) DescriptionCleared() bool {
+ _, ok := m.clearedFields[group.FieldDescription]
+ return ok
}
-// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids.
-func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) {
- if m.subscriptions == nil {
- m.subscriptions = make(map[int64]struct{})
- }
- for i := range ids {
- m.subscriptions[ids[i]] = struct{}{}
- }
+// ResetDescription resets all changes to the "description" field.
+func (m *GroupMutation) ResetDescription() {
+ m.description = nil
+ delete(m.clearedFields, group.FieldDescription)
}
-// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity.
-func (m *GroupMutation) ClearSubscriptions() {
- m.clearedsubscriptions = true
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (m *GroupMutation) SetRateMultiplier(f float64) {
+ m.rate_multiplier = &f
+ m.addrate_multiplier = nil
}
-// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared.
-func (m *GroupMutation) SubscriptionsCleared() bool {
- return m.clearedsubscriptions
+// RateMultiplier returns the value of the "rate_multiplier" field in the mutation.
+func (m *GroupMutation) RateMultiplier() (r float64, exists bool) {
+ v := m.rate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs.
-func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) {
- if m.removedsubscriptions == nil {
- m.removedsubscriptions = make(map[int64]struct{})
+// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations")
}
- for i := range ids {
- delete(m.subscriptions, ids[i])
- m.removedsubscriptions[ids[i]] = struct{}{}
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRateMultiplier requires an ID field in the mutation")
}
-}
-
-// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity.
-func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) {
- for id := range m.removedsubscriptions {
- ids = append(ids, id)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err)
}
- return
+ return oldValue.RateMultiplier, nil
}
-// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation.
-func (m *GroupMutation) SubscriptionsIDs() (ids []int64) {
- for id := range m.subscriptions {
- ids = append(ids, id)
+// AddRateMultiplier adds f to the "rate_multiplier" field.
+func (m *GroupMutation) AddRateMultiplier(f float64) {
+ if m.addrate_multiplier != nil {
+ *m.addrate_multiplier += f
+ } else {
+ m.addrate_multiplier = &f
}
- return
}
-// ResetSubscriptions resets all changes to the "subscriptions" edge.
-func (m *GroupMutation) ResetSubscriptions() {
- m.subscriptions = nil
- m.clearedsubscriptions = false
- m.removedsubscriptions = nil
+// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation.
+func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) {
+ v := m.addrate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids.
-func (m *GroupMutation) AddUsageLogIDs(ids ...int64) {
- if m.usage_logs == nil {
- m.usage_logs = make(map[int64]struct{})
- }
- for i := range ids {
- m.usage_logs[ids[i]] = struct{}{}
- }
+// ResetRateMultiplier resets all changes to the "rate_multiplier" field.
+func (m *GroupMutation) ResetRateMultiplier() {
+ m.rate_multiplier = nil
+ m.addrate_multiplier = nil
}
-// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity.
-func (m *GroupMutation) ClearUsageLogs() {
- m.clearedusage_logs = true
+// SetIsExclusive sets the "is_exclusive" field.
+func (m *GroupMutation) SetIsExclusive(b bool) {
+ m.is_exclusive = &b
}
-// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared.
-func (m *GroupMutation) UsageLogsCleared() bool {
- return m.clearedusage_logs
+// IsExclusive returns the value of the "is_exclusive" field in the mutation.
+func (m *GroupMutation) IsExclusive() (r bool, exists bool) {
+ v := m.is_exclusive
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs.
-func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) {
- if m.removedusage_logs == nil {
- m.removedusage_logs = make(map[int64]struct{})
+// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations")
}
- for i := range ids {
- delete(m.usage_logs, ids[i])
- m.removedusage_logs[ids[i]] = struct{}{}
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIsExclusive requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err)
}
+ return oldValue.IsExclusive, nil
}
-// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity.
-func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) {
- for id := range m.removedusage_logs {
- ids = append(ids, id)
- }
- return
+// ResetIsExclusive resets all changes to the "is_exclusive" field.
+func (m *GroupMutation) ResetIsExclusive() {
+ m.is_exclusive = nil
}
-// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation.
-func (m *GroupMutation) UsageLogsIDs() (ids []int64) {
- for id := range m.usage_logs {
- ids = append(ids, id)
- }
- return
+// SetStatus sets the "status" field.
+func (m *GroupMutation) SetStatus(s string) {
+ m.status = &s
}
-// ResetUsageLogs resets all changes to the "usage_logs" edge.
-func (m *GroupMutation) ResetUsageLogs() {
- m.usage_logs = nil
- m.clearedusage_logs = false
- m.removedusage_logs = nil
+// Status returns the value of the "status" field in the mutation.
+func (m *GroupMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// AddAccountIDs adds the "accounts" edge to the Account entity by ids.
-func (m *GroupMutation) AddAccountIDs(ids ...int64) {
- if m.accounts == nil {
- m.accounts = make(map[int64]struct{})
+// OldStatus returns the old "status" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
}
- for i := range ids {
- m.accounts[ids[i]] = struct{}{}
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
}
+ return oldValue.Status, nil
}
-// ClearAccounts clears the "accounts" edge to the Account entity.
-func (m *GroupMutation) ClearAccounts() {
- m.clearedaccounts = true
+// ResetStatus resets all changes to the "status" field.
+func (m *GroupMutation) ResetStatus() {
+ m.status = nil
}
-// AccountsCleared reports if the "accounts" edge to the Account entity was cleared.
-func (m *GroupMutation) AccountsCleared() bool {
- return m.clearedaccounts
+// SetPlatform sets the "platform" field.
+func (m *GroupMutation) SetPlatform(s string) {
+ m.platform = &s
}
-// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs.
-func (m *GroupMutation) RemoveAccountIDs(ids ...int64) {
- if m.removedaccounts == nil {
- m.removedaccounts = make(map[int64]struct{})
- }
- for i := range ids {
- delete(m.accounts, ids[i])
- m.removedaccounts[ids[i]] = struct{}{}
+// Platform returns the value of the "platform" field in the mutation.
+func (m *GroupMutation) Platform() (r string, exists bool) {
+ v := m.platform
+ if v == nil {
+ return
}
+ return *v, true
}
-// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity.
-func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) {
- for id := range m.removedaccounts {
- ids = append(ids, id)
+// OldPlatform returns the old "platform" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPlatform is only allowed on UpdateOne operations")
}
- return
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPlatform requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPlatform: %w", err)
+ }
+ return oldValue.Platform, nil
}
-// AccountsIDs returns the "accounts" edge IDs in the mutation.
-func (m *GroupMutation) AccountsIDs() (ids []int64) {
- for id := range m.accounts {
- ids = append(ids, id)
- }
- return
+// ResetPlatform resets all changes to the "platform" field.
+func (m *GroupMutation) ResetPlatform() {
+ m.platform = nil
}
-// ResetAccounts resets all changes to the "accounts" edge.
-func (m *GroupMutation) ResetAccounts() {
- m.accounts = nil
- m.clearedaccounts = false
- m.removedaccounts = nil
+// SetSubscriptionType sets the "subscription_type" field.
+func (m *GroupMutation) SetSubscriptionType(s string) {
+ m.subscription_type = &s
}
-// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids.
-func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) {
- if m.allowed_users == nil {
- m.allowed_users = make(map[int64]struct{})
+// SubscriptionType returns the value of the "subscription_type" field in the mutation.
+func (m *GroupMutation) SubscriptionType() (r string, exists bool) {
+ v := m.subscription_type
+ if v == nil {
+ return
}
- for i := range ids {
- m.allowed_users[ids[i]] = struct{}{}
+ return *v, true
+}
+
+// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSubscriptionType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err)
}
+ return oldValue.SubscriptionType, nil
}
-// ClearAllowedUsers clears the "allowed_users" edge to the User entity.
-func (m *GroupMutation) ClearAllowedUsers() {
- m.clearedallowed_users = true
+// ResetSubscriptionType resets all changes to the "subscription_type" field.
+func (m *GroupMutation) ResetSubscriptionType() {
+ m.subscription_type = nil
}
-// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared.
-func (m *GroupMutation) AllowedUsersCleared() bool {
- return m.clearedallowed_users
+// SetDailyLimitUsd sets the "daily_limit_usd" field.
+func (m *GroupMutation) SetDailyLimitUsd(f float64) {
+ m.daily_limit_usd = &f
+ m.adddaily_limit_usd = nil
}
-// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs.
-func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) {
- if m.removedallowed_users == nil {
- m.removedallowed_users = make(map[int64]struct{})
- }
- for i := range ids {
- delete(m.allowed_users, ids[i])
- m.removedallowed_users[ids[i]] = struct{}{}
+// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation.
+func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) {
+ v := m.daily_limit_usd
+ if v == nil {
+ return
}
+ return *v, true
}
-// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity.
-func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) {
- for id := range m.removedallowed_users {
- ids = append(ids, id)
+// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations")
}
- return
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err)
+ }
+ return oldValue.DailyLimitUsd, nil
}
-// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation.
-func (m *GroupMutation) AllowedUsersIDs() (ids []int64) {
- for id := range m.allowed_users {
- ids = append(ids, id)
+// AddDailyLimitUsd adds f to the "daily_limit_usd" field.
+func (m *GroupMutation) AddDailyLimitUsd(f float64) {
+ if m.adddaily_limit_usd != nil {
+ *m.adddaily_limit_usd += f
+ } else {
+ m.adddaily_limit_usd = &f
}
- return
}
-// ResetAllowedUsers resets all changes to the "allowed_users" edge.
-func (m *GroupMutation) ResetAllowedUsers() {
- m.allowed_users = nil
- m.clearedallowed_users = false
- m.removedallowed_users = nil
+// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation.
+func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) {
+ v := m.adddaily_limit_usd
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// Where appends a list predicates to the GroupMutation builder.
-func (m *GroupMutation) Where(ps ...predicate.Group) {
- m.predicates = append(m.predicates, ps...)
+// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
+func (m *GroupMutation) ClearDailyLimitUsd() {
+ m.daily_limit_usd = nil
+ m.adddaily_limit_usd = nil
+ m.clearedFields[group.FieldDailyLimitUsd] = struct{}{}
}
-// WhereP appends storage-level predicates to the GroupMutation builder. Using this method,
-// users can use type-assertion to append predicates that do not depend on any generated package.
-func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) {
- p := make([]predicate.Group, len(ps))
- for i := range ps {
- p[i] = ps[i]
- }
- m.Where(p...)
+// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation.
+func (m *GroupMutation) DailyLimitUsdCleared() bool {
+ _, ok := m.clearedFields[group.FieldDailyLimitUsd]
+ return ok
}
-// Op returns the operation name.
-func (m *GroupMutation) Op() Op {
- return m.op
+// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field.
+func (m *GroupMutation) ResetDailyLimitUsd() {
+ m.daily_limit_usd = nil
+ m.adddaily_limit_usd = nil
+ delete(m.clearedFields, group.FieldDailyLimitUsd)
}
-// SetOp allows setting the mutation operation.
-func (m *GroupMutation) SetOp(op Op) {
- m.op = op
+// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
+func (m *GroupMutation) SetWeeklyLimitUsd(f float64) {
+ m.weekly_limit_usd = &f
+ m.addweekly_limit_usd = nil
}
-// Type returns the node type of this mutation (Group).
-func (m *GroupMutation) Type() string {
- return m.typ
+// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation.
+func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) {
+ v := m.weekly_limit_usd
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// Fields returns all fields that were changed during this mutation. Note that in
-// order to get all numeric fields that were incremented/decremented, call
-// AddedFields().
-func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 30)
- if m.created_at != nil {
- fields = append(fields, group.FieldCreatedAt)
- }
- if m.updated_at != nil {
- fields = append(fields, group.FieldUpdatedAt)
- }
- if m.deleted_at != nil {
- fields = append(fields, group.FieldDeletedAt)
- }
- if m.name != nil {
- fields = append(fields, group.FieldName)
- }
- if m.description != nil {
- fields = append(fields, group.FieldDescription)
- }
- if m.rate_multiplier != nil {
- fields = append(fields, group.FieldRateMultiplier)
- }
- if m.is_exclusive != nil {
- fields = append(fields, group.FieldIsExclusive)
- }
- if m.status != nil {
- fields = append(fields, group.FieldStatus)
- }
- if m.platform != nil {
- fields = append(fields, group.FieldPlatform)
- }
- if m.subscription_type != nil {
- fields = append(fields, group.FieldSubscriptionType)
- }
- if m.daily_limit_usd != nil {
- fields = append(fields, group.FieldDailyLimitUsd)
- }
- if m.weekly_limit_usd != nil {
- fields = append(fields, group.FieldWeeklyLimitUsd)
+// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations")
}
- if m.monthly_limit_usd != nil {
- fields = append(fields, group.FieldMonthlyLimitUsd)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation")
}
- if m.default_validity_days != nil {
- fields = append(fields, group.FieldDefaultValidityDays)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err)
}
- if m.image_price_1k != nil {
- fields = append(fields, group.FieldImagePrice1k)
+ return oldValue.WeeklyLimitUsd, nil
+}
+
+// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field.
+func (m *GroupMutation) AddWeeklyLimitUsd(f float64) {
+ if m.addweekly_limit_usd != nil {
+ *m.addweekly_limit_usd += f
+ } else {
+ m.addweekly_limit_usd = &f
}
- if m.image_price_2k != nil {
- fields = append(fields, group.FieldImagePrice2k)
+}
+
+// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation.
+func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) {
+ v := m.addweekly_limit_usd
+ if v == nil {
+ return
}
- if m.image_price_4k != nil {
- fields = append(fields, group.FieldImagePrice4k)
- }
- if m.claude_code_only != nil {
- fields = append(fields, group.FieldClaudeCodeOnly)
- }
- if m.fallback_group_id != nil {
- fields = append(fields, group.FieldFallbackGroupID)
+ return *v, true
+}
+
+// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
+func (m *GroupMutation) ClearWeeklyLimitUsd() {
+ m.weekly_limit_usd = nil
+ m.addweekly_limit_usd = nil
+ m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{}
+}
+
+// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation.
+func (m *GroupMutation) WeeklyLimitUsdCleared() bool {
+ _, ok := m.clearedFields[group.FieldWeeklyLimitUsd]
+ return ok
+}
+
+// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field.
+func (m *GroupMutation) ResetWeeklyLimitUsd() {
+ m.weekly_limit_usd = nil
+ m.addweekly_limit_usd = nil
+ delete(m.clearedFields, group.FieldWeeklyLimitUsd)
+}
+
+// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
+func (m *GroupMutation) SetMonthlyLimitUsd(f float64) {
+ m.monthly_limit_usd = &f
+ m.addmonthly_limit_usd = nil
+}
+
+// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation.
+func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) {
+ v := m.monthly_limit_usd
+ if v == nil {
+ return
}
- if m.fallback_group_id_on_invalid_request != nil {
- fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
+ return *v, true
+}
+
+// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations")
}
- if m.model_routing != nil {
- fields = append(fields, group.FieldModelRouting)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation")
}
- if m.model_routing_enabled != nil {
- fields = append(fields, group.FieldModelRoutingEnabled)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err)
}
- if m.mcp_xml_inject != nil {
- fields = append(fields, group.FieldMcpXMLInject)
+ return oldValue.MonthlyLimitUsd, nil
+}
+
+// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field.
+func (m *GroupMutation) AddMonthlyLimitUsd(f float64) {
+ if m.addmonthly_limit_usd != nil {
+ *m.addmonthly_limit_usd += f
+ } else {
+ m.addmonthly_limit_usd = &f
}
- if m.supported_model_scopes != nil {
- fields = append(fields, group.FieldSupportedModelScopes)
+}
+
+// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation.
+func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) {
+ v := m.addmonthly_limit_usd
+ if v == nil {
+ return
}
- if m.sort_order != nil {
- fields = append(fields, group.FieldSortOrder)
+ return *v, true
+}
+
+// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
+func (m *GroupMutation) ClearMonthlyLimitUsd() {
+ m.monthly_limit_usd = nil
+ m.addmonthly_limit_usd = nil
+ m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{}
+}
+
+// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation.
+func (m *GroupMutation) MonthlyLimitUsdCleared() bool {
+ _, ok := m.clearedFields[group.FieldMonthlyLimitUsd]
+ return ok
+}
+
+// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field.
+func (m *GroupMutation) ResetMonthlyLimitUsd() {
+ m.monthly_limit_usd = nil
+ m.addmonthly_limit_usd = nil
+ delete(m.clearedFields, group.FieldMonthlyLimitUsd)
+}
+
+// SetDefaultValidityDays sets the "default_validity_days" field.
+func (m *GroupMutation) SetDefaultValidityDays(i int) {
+ m.default_validity_days = &i
+ m.adddefault_validity_days = nil
+}
+
+// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation.
+func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) {
+ v := m.default_validity_days
+ if v == nil {
+ return
}
- if m.allow_messages_dispatch != nil {
- fields = append(fields, group.FieldAllowMessagesDispatch)
+ return *v, true
+}
+
+// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations")
}
- if m.require_oauth_only != nil {
- fields = append(fields, group.FieldRequireOauthOnly)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation")
}
- if m.require_privacy_set != nil {
- fields = append(fields, group.FieldRequirePrivacySet)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err)
}
- if m.default_mapped_model != nil {
- fields = append(fields, group.FieldDefaultMappedModel)
+ return oldValue.DefaultValidityDays, nil
+}
+
+// AddDefaultValidityDays adds i to the "default_validity_days" field.
+func (m *GroupMutation) AddDefaultValidityDays(i int) {
+ if m.adddefault_validity_days != nil {
+ *m.adddefault_validity_days += i
+ } else {
+ m.adddefault_validity_days = &i
}
- if m.messages_dispatch_model_config != nil {
- fields = append(fields, group.FieldMessagesDispatchModelConfig)
+}
+
+// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation.
+func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) {
+ v := m.adddefault_validity_days
+ if v == nil {
+ return
}
- return fields
+ return *v, true
}
-// Field returns the value of a field with the given name. The second boolean
-// return value indicates that this field was not set, or was not defined in the
-// schema.
-func (m *GroupMutation) Field(name string) (ent.Value, bool) {
- switch name {
- case group.FieldCreatedAt:
- return m.CreatedAt()
- case group.FieldUpdatedAt:
- return m.UpdatedAt()
- case group.FieldDeletedAt:
- return m.DeletedAt()
- case group.FieldName:
- return m.Name()
- case group.FieldDescription:
- return m.Description()
- case group.FieldRateMultiplier:
- return m.RateMultiplier()
- case group.FieldIsExclusive:
- return m.IsExclusive()
- case group.FieldStatus:
- return m.Status()
- case group.FieldPlatform:
- return m.Platform()
- case group.FieldSubscriptionType:
- return m.SubscriptionType()
- case group.FieldDailyLimitUsd:
- return m.DailyLimitUsd()
- case group.FieldWeeklyLimitUsd:
- return m.WeeklyLimitUsd()
- case group.FieldMonthlyLimitUsd:
- return m.MonthlyLimitUsd()
- case group.FieldDefaultValidityDays:
- return m.DefaultValidityDays()
- case group.FieldImagePrice1k:
- return m.ImagePrice1k()
- case group.FieldImagePrice2k:
- return m.ImagePrice2k()
- case group.FieldImagePrice4k:
- return m.ImagePrice4k()
- case group.FieldClaudeCodeOnly:
- return m.ClaudeCodeOnly()
- case group.FieldFallbackGroupID:
- return m.FallbackGroupID()
- case group.FieldFallbackGroupIDOnInvalidRequest:
- return m.FallbackGroupIDOnInvalidRequest()
- case group.FieldModelRouting:
- return m.ModelRouting()
- case group.FieldModelRoutingEnabled:
- return m.ModelRoutingEnabled()
- case group.FieldMcpXMLInject:
- return m.McpXMLInject()
- case group.FieldSupportedModelScopes:
- return m.SupportedModelScopes()
- case group.FieldSortOrder:
- return m.SortOrder()
- case group.FieldAllowMessagesDispatch:
- return m.AllowMessagesDispatch()
- case group.FieldRequireOauthOnly:
- return m.RequireOauthOnly()
- case group.FieldRequirePrivacySet:
- return m.RequirePrivacySet()
- case group.FieldDefaultMappedModel:
- return m.DefaultMappedModel()
- case group.FieldMessagesDispatchModelConfig:
- return m.MessagesDispatchModelConfig()
+// ResetDefaultValidityDays resets all changes to the "default_validity_days" field.
+func (m *GroupMutation) ResetDefaultValidityDays() {
+ m.default_validity_days = nil
+ m.adddefault_validity_days = nil
+}
+
+// SetImagePrice1k sets the "image_price_1k" field.
+func (m *GroupMutation) SetImagePrice1k(f float64) {
+ m.image_price_1k = &f
+ m.addimage_price_1k = nil
+}
+
+// ImagePrice1k returns the value of the "image_price_1k" field in the mutation.
+func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) {
+ v := m.image_price_1k
+ if v == nil {
+ return
}
- return nil, false
+ return *v, true
}
-// OldField returns the old value of the field from the database. An error is
-// returned if the mutation operation is not UpdateOne, or the query to the
-// database failed.
-func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
- switch name {
- case group.FieldCreatedAt:
- return m.OldCreatedAt(ctx)
- case group.FieldUpdatedAt:
- return m.OldUpdatedAt(ctx)
- case group.FieldDeletedAt:
- return m.OldDeletedAt(ctx)
- case group.FieldName:
- return m.OldName(ctx)
- case group.FieldDescription:
- return m.OldDescription(ctx)
- case group.FieldRateMultiplier:
- return m.OldRateMultiplier(ctx)
- case group.FieldIsExclusive:
- return m.OldIsExclusive(ctx)
- case group.FieldStatus:
- return m.OldStatus(ctx)
- case group.FieldPlatform:
- return m.OldPlatform(ctx)
- case group.FieldSubscriptionType:
- return m.OldSubscriptionType(ctx)
- case group.FieldDailyLimitUsd:
- return m.OldDailyLimitUsd(ctx)
- case group.FieldWeeklyLimitUsd:
- return m.OldWeeklyLimitUsd(ctx)
- case group.FieldMonthlyLimitUsd:
- return m.OldMonthlyLimitUsd(ctx)
- case group.FieldDefaultValidityDays:
- return m.OldDefaultValidityDays(ctx)
- case group.FieldImagePrice1k:
- return m.OldImagePrice1k(ctx)
- case group.FieldImagePrice2k:
- return m.OldImagePrice2k(ctx)
- case group.FieldImagePrice4k:
- return m.OldImagePrice4k(ctx)
- case group.FieldClaudeCodeOnly:
- return m.OldClaudeCodeOnly(ctx)
- case group.FieldFallbackGroupID:
- return m.OldFallbackGroupID(ctx)
- case group.FieldFallbackGroupIDOnInvalidRequest:
- return m.OldFallbackGroupIDOnInvalidRequest(ctx)
- case group.FieldModelRouting:
- return m.OldModelRouting(ctx)
- case group.FieldModelRoutingEnabled:
- return m.OldModelRoutingEnabled(ctx)
- case group.FieldMcpXMLInject:
- return m.OldMcpXMLInject(ctx)
- case group.FieldSupportedModelScopes:
- return m.OldSupportedModelScopes(ctx)
- case group.FieldSortOrder:
- return m.OldSortOrder(ctx)
- case group.FieldAllowMessagesDispatch:
- return m.OldAllowMessagesDispatch(ctx)
- case group.FieldRequireOauthOnly:
- return m.OldRequireOauthOnly(ctx)
- case group.FieldRequirePrivacySet:
- return m.OldRequirePrivacySet(ctx)
- case group.FieldDefaultMappedModel:
- return m.OldDefaultMappedModel(ctx)
- case group.FieldMessagesDispatchModelConfig:
- return m.OldMessagesDispatchModelConfig(ctx)
+// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations")
}
- return nil, fmt.Errorf("unknown Group field %s", name)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImagePrice1k requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err)
+ }
+ return oldValue.ImagePrice1k, nil
}
-// SetField sets the value of a field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *GroupMutation) SetField(name string, value ent.Value) error {
- switch name {
- case group.FieldCreatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetCreatedAt(v)
- return nil
- case group.FieldUpdatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetUpdatedAt(v)
- return nil
- case group.FieldDeletedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDeletedAt(v)
- return nil
- case group.FieldName:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetName(v)
- return nil
- case group.FieldDescription:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDescription(v)
- return nil
- case group.FieldRateMultiplier:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRateMultiplier(v)
- return nil
- case group.FieldIsExclusive:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetIsExclusive(v)
- return nil
- case group.FieldStatus:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetStatus(v)
- return nil
- case group.FieldPlatform:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPlatform(v)
- return nil
- case group.FieldSubscriptionType:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSubscriptionType(v)
- return nil
- case group.FieldDailyLimitUsd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDailyLimitUsd(v)
- return nil
- case group.FieldWeeklyLimitUsd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetWeeklyLimitUsd(v)
- return nil
- case group.FieldMonthlyLimitUsd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetMonthlyLimitUsd(v)
- return nil
- case group.FieldDefaultValidityDays:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDefaultValidityDays(v)
- return nil
- case group.FieldImagePrice1k:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetImagePrice1k(v)
- return nil
- case group.FieldImagePrice2k:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetImagePrice2k(v)
- return nil
- case group.FieldImagePrice4k:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetImagePrice4k(v)
- return nil
- case group.FieldClaudeCodeOnly:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetClaudeCodeOnly(v)
- return nil
- case group.FieldFallbackGroupID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetFallbackGroupID(v)
- return nil
- case group.FieldFallbackGroupIDOnInvalidRequest:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetFallbackGroupIDOnInvalidRequest(v)
- return nil
- case group.FieldModelRouting:
- v, ok := value.(map[string][]int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetModelRouting(v)
- return nil
- case group.FieldModelRoutingEnabled:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetModelRoutingEnabled(v)
- return nil
- case group.FieldMcpXMLInject:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetMcpXMLInject(v)
- return nil
- case group.FieldSupportedModelScopes:
- v, ok := value.([]string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSupportedModelScopes(v)
- return nil
- case group.FieldSortOrder:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSortOrder(v)
- return nil
- case group.FieldAllowMessagesDispatch:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetAllowMessagesDispatch(v)
- return nil
- case group.FieldRequireOauthOnly:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRequireOauthOnly(v)
- return nil
- case group.FieldRequirePrivacySet:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRequirePrivacySet(v)
- return nil
- case group.FieldDefaultMappedModel:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDefaultMappedModel(v)
- return nil
- case group.FieldMessagesDispatchModelConfig:
- v, ok := value.(domain.OpenAIMessagesDispatchModelConfig)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetMessagesDispatchModelConfig(v)
- return nil
+// AddImagePrice1k adds f to the "image_price_1k" field.
+func (m *GroupMutation) AddImagePrice1k(f float64) {
+ if m.addimage_price_1k != nil {
+ *m.addimage_price_1k += f
+ } else {
+ m.addimage_price_1k = &f
}
- return fmt.Errorf("unknown Group field %s", name)
}
-// AddedFields returns all numeric fields that were incremented/decremented during
-// this mutation.
-func (m *GroupMutation) AddedFields() []string {
- var fields []string
- if m.addrate_multiplier != nil {
- fields = append(fields, group.FieldRateMultiplier)
- }
- if m.adddaily_limit_usd != nil {
- fields = append(fields, group.FieldDailyLimitUsd)
+// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation.
+func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) {
+ v := m.addimage_price_1k
+ if v == nil {
+ return
}
- if m.addweekly_limit_usd != nil {
- fields = append(fields, group.FieldWeeklyLimitUsd)
+ return *v, true
+}
+
+// ClearImagePrice1k clears the value of the "image_price_1k" field.
+func (m *GroupMutation) ClearImagePrice1k() {
+ m.image_price_1k = nil
+ m.addimage_price_1k = nil
+ m.clearedFields[group.FieldImagePrice1k] = struct{}{}
+}
+
+// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation.
+func (m *GroupMutation) ImagePrice1kCleared() bool {
+ _, ok := m.clearedFields[group.FieldImagePrice1k]
+ return ok
+}
+
+// ResetImagePrice1k resets all changes to the "image_price_1k" field.
+func (m *GroupMutation) ResetImagePrice1k() {
+ m.image_price_1k = nil
+ m.addimage_price_1k = nil
+ delete(m.clearedFields, group.FieldImagePrice1k)
+}
+
+// SetImagePrice2k sets the "image_price_2k" field.
+func (m *GroupMutation) SetImagePrice2k(f float64) {
+ m.image_price_2k = &f
+ m.addimage_price_2k = nil
+}
+
+// ImagePrice2k returns the value of the "image_price_2k" field in the mutation.
+func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) {
+ v := m.image_price_2k
+ if v == nil {
+ return
}
- if m.addmonthly_limit_usd != nil {
- fields = append(fields, group.FieldMonthlyLimitUsd)
+ return *v, true
+}
+
+// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations")
}
- if m.adddefault_validity_days != nil {
- fields = append(fields, group.FieldDefaultValidityDays)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImagePrice2k requires an ID field in the mutation")
}
- if m.addimage_price_1k != nil {
- fields = append(fields, group.FieldImagePrice1k)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err)
}
+ return oldValue.ImagePrice2k, nil
+}
+
+// AddImagePrice2k adds f to the "image_price_2k" field.
+func (m *GroupMutation) AddImagePrice2k(f float64) {
if m.addimage_price_2k != nil {
- fields = append(fields, group.FieldImagePrice2k)
+ *m.addimage_price_2k += f
+ } else {
+ m.addimage_price_2k = &f
}
- if m.addimage_price_4k != nil {
- fields = append(fields, group.FieldImagePrice4k)
+}
+
+// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation.
+func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) {
+ v := m.addimage_price_2k
+ if v == nil {
+ return
}
- if m.addfallback_group_id != nil {
- fields = append(fields, group.FieldFallbackGroupID)
+ return *v, true
+}
+
+// ClearImagePrice2k clears the value of the "image_price_2k" field.
+func (m *GroupMutation) ClearImagePrice2k() {
+ m.image_price_2k = nil
+ m.addimage_price_2k = nil
+ m.clearedFields[group.FieldImagePrice2k] = struct{}{}
+}
+
+// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation.
+func (m *GroupMutation) ImagePrice2kCleared() bool {
+ _, ok := m.clearedFields[group.FieldImagePrice2k]
+ return ok
+}
+
+// ResetImagePrice2k resets all changes to the "image_price_2k" field.
+func (m *GroupMutation) ResetImagePrice2k() {
+ m.image_price_2k = nil
+ m.addimage_price_2k = nil
+ delete(m.clearedFields, group.FieldImagePrice2k)
+}
+
+// SetImagePrice4k sets the "image_price_4k" field.
+func (m *GroupMutation) SetImagePrice4k(f float64) {
+ m.image_price_4k = &f
+ m.addimage_price_4k = nil
+}
+
+// ImagePrice4k returns the value of the "image_price_4k" field in the mutation.
+func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) {
+ v := m.image_price_4k
+ if v == nil {
+ return
}
- if m.addfallback_group_id_on_invalid_request != nil {
- fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
+ return *v, true
+}
+
+// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations")
}
- if m.addsort_order != nil {
- fields = append(fields, group.FieldSortOrder)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImagePrice4k requires an ID field in the mutation")
}
- return fields
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err)
+ }
+ return oldValue.ImagePrice4k, nil
}
-// AddedField returns the numeric value that was incremented/decremented on a field
-// with the given name. The second boolean return value indicates that this field
-// was not set, or was not defined in the schema.
-func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
- switch name {
- case group.FieldRateMultiplier:
- return m.AddedRateMultiplier()
- case group.FieldDailyLimitUsd:
- return m.AddedDailyLimitUsd()
- case group.FieldWeeklyLimitUsd:
- return m.AddedWeeklyLimitUsd()
- case group.FieldMonthlyLimitUsd:
- return m.AddedMonthlyLimitUsd()
- case group.FieldDefaultValidityDays:
- return m.AddedDefaultValidityDays()
- case group.FieldImagePrice1k:
- return m.AddedImagePrice1k()
- case group.FieldImagePrice2k:
- return m.AddedImagePrice2k()
- case group.FieldImagePrice4k:
- return m.AddedImagePrice4k()
- case group.FieldFallbackGroupID:
- return m.AddedFallbackGroupID()
- case group.FieldFallbackGroupIDOnInvalidRequest:
- return m.AddedFallbackGroupIDOnInvalidRequest()
- case group.FieldSortOrder:
- return m.AddedSortOrder()
+// AddImagePrice4k adds f to the "image_price_4k" field.
+func (m *GroupMutation) AddImagePrice4k(f float64) {
+ if m.addimage_price_4k != nil {
+ *m.addimage_price_4k += f
+ } else {
+ m.addimage_price_4k = &f
}
- return nil, false
}
-// AddField adds the value to the field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *GroupMutation) AddField(name string, value ent.Value) error {
- switch name {
- case group.FieldRateMultiplier:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddRateMultiplier(v)
- return nil
- case group.FieldDailyLimitUsd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddDailyLimitUsd(v)
- return nil
- case group.FieldWeeklyLimitUsd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddWeeklyLimitUsd(v)
- return nil
- case group.FieldMonthlyLimitUsd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddMonthlyLimitUsd(v)
- return nil
- case group.FieldDefaultValidityDays:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddDefaultValidityDays(v)
- return nil
- case group.FieldImagePrice1k:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddImagePrice1k(v)
- return nil
- case group.FieldImagePrice2k:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddImagePrice2k(v)
- return nil
- case group.FieldImagePrice4k:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddImagePrice4k(v)
- return nil
- case group.FieldFallbackGroupID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddFallbackGroupID(v)
- return nil
- case group.FieldFallbackGroupIDOnInvalidRequest:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddFallbackGroupIDOnInvalidRequest(v)
- return nil
- case group.FieldSortOrder:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSortOrder(v)
- return nil
+// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation.
+func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) {
+ v := m.addimage_price_4k
+ if v == nil {
+ return
}
- return fmt.Errorf("unknown Group numeric field %s", name)
+ return *v, true
}
-// ClearedFields returns all nullable fields that were cleared during this
-// mutation.
-func (m *GroupMutation) ClearedFields() []string {
- var fields []string
- if m.FieldCleared(group.FieldDeletedAt) {
- fields = append(fields, group.FieldDeletedAt)
- }
- if m.FieldCleared(group.FieldDescription) {
- fields = append(fields, group.FieldDescription)
+// ClearImagePrice4k clears the value of the "image_price_4k" field.
+func (m *GroupMutation) ClearImagePrice4k() {
+ m.image_price_4k = nil
+ m.addimage_price_4k = nil
+ m.clearedFields[group.FieldImagePrice4k] = struct{}{}
+}
+
+// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation.
+func (m *GroupMutation) ImagePrice4kCleared() bool {
+ _, ok := m.clearedFields[group.FieldImagePrice4k]
+ return ok
+}
+
+// ResetImagePrice4k resets all changes to the "image_price_4k" field.
+func (m *GroupMutation) ResetImagePrice4k() {
+ m.image_price_4k = nil
+ m.addimage_price_4k = nil
+ delete(m.clearedFields, group.FieldImagePrice4k)
+}
+
+// SetClaudeCodeOnly sets the "claude_code_only" field.
+func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
+ m.claude_code_only = &b
+}
+
+// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
+func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
+ v := m.claude_code_only
+ if v == nil {
+ return
}
- if m.FieldCleared(group.FieldDailyLimitUsd) {
- fields = append(fields, group.FieldDailyLimitUsd)
+ return *v, true
+}
+
+// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
}
- if m.FieldCleared(group.FieldWeeklyLimitUsd) {
- fields = append(fields, group.FieldWeeklyLimitUsd)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
}
- if m.FieldCleared(group.FieldMonthlyLimitUsd) {
- fields = append(fields, group.FieldMonthlyLimitUsd)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
}
- if m.FieldCleared(group.FieldImagePrice1k) {
- fields = append(fields, group.FieldImagePrice1k)
+ return oldValue.ClaudeCodeOnly, nil
+}
+
+// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
+func (m *GroupMutation) ResetClaudeCodeOnly() {
+ m.claude_code_only = nil
+}
+
+// SetFallbackGroupID sets the "fallback_group_id" field.
+func (m *GroupMutation) SetFallbackGroupID(i int64) {
+ m.fallback_group_id = &i
+ m.addfallback_group_id = nil
+}
+
+// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
+func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
+ v := m.fallback_group_id
+ if v == nil {
+ return
}
- if m.FieldCleared(group.FieldImagePrice2k) {
- fields = append(fields, group.FieldImagePrice2k)
+ return *v, true
+}
+
+// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
}
- if m.FieldCleared(group.FieldImagePrice4k) {
- fields = append(fields, group.FieldImagePrice4k)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
}
- if m.FieldCleared(group.FieldFallbackGroupID) {
- fields = append(fields, group.FieldFallbackGroupID)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
}
- if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) {
- fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
+ return oldValue.FallbackGroupID, nil
+}
+
+// AddFallbackGroupID adds i to the "fallback_group_id" field.
+func (m *GroupMutation) AddFallbackGroupID(i int64) {
+ if m.addfallback_group_id != nil {
+ *m.addfallback_group_id += i
+ } else {
+ m.addfallback_group_id = &i
}
- if m.FieldCleared(group.FieldModelRouting) {
- fields = append(fields, group.FieldModelRouting)
+}
+
+// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
+func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
+ v := m.addfallback_group_id
+ if v == nil {
+ return
}
- return fields
+ return *v, true
}
-// FieldCleared returns a boolean indicating if a field with the given name was
-// cleared in this mutation.
-func (m *GroupMutation) FieldCleared(name string) bool {
- _, ok := m.clearedFields[name]
+// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
+func (m *GroupMutation) ClearFallbackGroupID() {
+ m.fallback_group_id = nil
+ m.addfallback_group_id = nil
+ m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
+}
+
+// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
+func (m *GroupMutation) FallbackGroupIDCleared() bool {
+ _, ok := m.clearedFields[group.FieldFallbackGroupID]
return ok
}
-// ClearField clears the value of the field with the given name. It returns an
-// error if the field is not defined in the schema.
-func (m *GroupMutation) ClearField(name string) error {
- switch name {
- case group.FieldDeletedAt:
- m.ClearDeletedAt()
- return nil
- case group.FieldDescription:
- m.ClearDescription()
- return nil
- case group.FieldDailyLimitUsd:
- m.ClearDailyLimitUsd()
- return nil
- case group.FieldWeeklyLimitUsd:
- m.ClearWeeklyLimitUsd()
- return nil
- case group.FieldMonthlyLimitUsd:
- m.ClearMonthlyLimitUsd()
- return nil
- case group.FieldImagePrice1k:
- m.ClearImagePrice1k()
- return nil
- case group.FieldImagePrice2k:
- m.ClearImagePrice2k()
- return nil
- case group.FieldImagePrice4k:
- m.ClearImagePrice4k()
- return nil
- case group.FieldFallbackGroupID:
- m.ClearFallbackGroupID()
- return nil
- case group.FieldFallbackGroupIDOnInvalidRequest:
- m.ClearFallbackGroupIDOnInvalidRequest()
- return nil
- case group.FieldModelRouting:
- m.ClearModelRouting()
- return nil
+// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
+func (m *GroupMutation) ResetFallbackGroupID() {
+ m.fallback_group_id = nil
+ m.addfallback_group_id = nil
+ delete(m.clearedFields, group.FieldFallbackGroupID)
+}
+
+// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
+func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) {
+ m.fallback_group_id_on_invalid_request = &i
+ m.addfallback_group_id_on_invalid_request = nil
+}
+
+// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation.
+func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
+ v := m.fallback_group_id_on_invalid_request
+ if v == nil {
+ return
}
- return fmt.Errorf("unknown Group nullable field %s", name)
+ return *v, true
}
-// ResetField resets all changes in the mutation for the field with the given name.
-// It returns an error if the field is not defined in the schema.
-func (m *GroupMutation) ResetField(name string) error {
- switch name {
- case group.FieldCreatedAt:
- m.ResetCreatedAt()
- return nil
- case group.FieldUpdatedAt:
- m.ResetUpdatedAt()
- return nil
- case group.FieldDeletedAt:
- m.ResetDeletedAt()
- return nil
- case group.FieldName:
- m.ResetName()
- return nil
- case group.FieldDescription:
- m.ResetDescription()
- return nil
- case group.FieldRateMultiplier:
- m.ResetRateMultiplier()
- return nil
- case group.FieldIsExclusive:
- m.ResetIsExclusive()
- return nil
- case group.FieldStatus:
- m.ResetStatus()
- return nil
- case group.FieldPlatform:
- m.ResetPlatform()
- return nil
- case group.FieldSubscriptionType:
- m.ResetSubscriptionType()
- return nil
- case group.FieldDailyLimitUsd:
- m.ResetDailyLimitUsd()
- return nil
- case group.FieldWeeklyLimitUsd:
- m.ResetWeeklyLimitUsd()
- return nil
- case group.FieldMonthlyLimitUsd:
- m.ResetMonthlyLimitUsd()
- return nil
- case group.FieldDefaultValidityDays:
- m.ResetDefaultValidityDays()
- return nil
- case group.FieldImagePrice1k:
- m.ResetImagePrice1k()
- return nil
- case group.FieldImagePrice2k:
- m.ResetImagePrice2k()
- return nil
- case group.FieldImagePrice4k:
- m.ResetImagePrice4k()
- return nil
- case group.FieldClaudeCodeOnly:
- m.ResetClaudeCodeOnly()
- return nil
- case group.FieldFallbackGroupID:
- m.ResetFallbackGroupID()
- return nil
- case group.FieldFallbackGroupIDOnInvalidRequest:
- m.ResetFallbackGroupIDOnInvalidRequest()
- return nil
- case group.FieldModelRouting:
- m.ResetModelRouting()
- return nil
- case group.FieldModelRoutingEnabled:
- m.ResetModelRoutingEnabled()
- return nil
- case group.FieldMcpXMLInject:
- m.ResetMcpXMLInject()
- return nil
- case group.FieldSupportedModelScopes:
- m.ResetSupportedModelScopes()
- return nil
- case group.FieldSortOrder:
- m.ResetSortOrder()
- return nil
- case group.FieldAllowMessagesDispatch:
- m.ResetAllowMessagesDispatch()
- return nil
- case group.FieldRequireOauthOnly:
- m.ResetRequireOauthOnly()
- return nil
- case group.FieldRequirePrivacySet:
- m.ResetRequirePrivacySet()
- return nil
- case group.FieldDefaultMappedModel:
- m.ResetDefaultMappedModel()
- return nil
- case group.FieldMessagesDispatchModelConfig:
- m.ResetMessagesDispatchModelConfig()
- return nil
+// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations")
}
- return fmt.Errorf("unknown Group field %s", name)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err)
+ }
+ return oldValue.FallbackGroupIDOnInvalidRequest, nil
}
-// AddedEdges returns all edge names that were set/added in this mutation.
-func (m *GroupMutation) AddedEdges() []string {
- edges := make([]string, 0, 6)
- if m.api_keys != nil {
- edges = append(edges, group.EdgeAPIKeys)
+// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field.
+func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) {
+ if m.addfallback_group_id_on_invalid_request != nil {
+ *m.addfallback_group_id_on_invalid_request += i
+ } else {
+ m.addfallback_group_id_on_invalid_request = &i
}
- if m.redeem_codes != nil {
- edges = append(edges, group.EdgeRedeemCodes)
+}
+
+// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation.
+func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
+ v := m.addfallback_group_id_on_invalid_request
+ if v == nil {
+ return
}
- if m.subscriptions != nil {
- edges = append(edges, group.EdgeSubscriptions)
+ return *v, true
+}
+
+// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
+func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() {
+ m.fallback_group_id_on_invalid_request = nil
+ m.addfallback_group_id_on_invalid_request = nil
+ m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{}
+}
+
+// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation.
+func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool {
+ _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest]
+ return ok
+}
+
+// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field.
+func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() {
+ m.fallback_group_id_on_invalid_request = nil
+ m.addfallback_group_id_on_invalid_request = nil
+ delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest)
+}
+
+// SetModelRouting sets the "model_routing" field.
+func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
+ m.model_routing = &value
+}
+
+// ModelRouting returns the value of the "model_routing" field in the mutation.
+func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) {
+ v := m.model_routing
+ if v == nil {
+ return
}
- if m.usage_logs != nil {
- edges = append(edges, group.EdgeUsageLogs)
+ return *v, true
+}
+
+// OldModelRouting returns the old "model_routing" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModelRouting is only allowed on UpdateOne operations")
}
- if m.accounts != nil {
- edges = append(edges, group.EdgeAccounts)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModelRouting requires an ID field in the mutation")
}
- if m.allowed_users != nil {
- edges = append(edges, group.EdgeAllowedUsers)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModelRouting: %w", err)
}
- return edges
+ return oldValue.ModelRouting, nil
}
-// AddedIDs returns all IDs (to other nodes) that were added for the given edge
-// name in this mutation.
-func (m *GroupMutation) AddedIDs(name string) []ent.Value {
- switch name {
- case group.EdgeAPIKeys:
- ids := make([]ent.Value, 0, len(m.api_keys))
- for id := range m.api_keys {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeRedeemCodes:
- ids := make([]ent.Value, 0, len(m.redeem_codes))
- for id := range m.redeem_codes {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeSubscriptions:
- ids := make([]ent.Value, 0, len(m.subscriptions))
- for id := range m.subscriptions {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeUsageLogs:
- ids := make([]ent.Value, 0, len(m.usage_logs))
- for id := range m.usage_logs {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeAccounts:
- ids := make([]ent.Value, 0, len(m.accounts))
- for id := range m.accounts {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeAllowedUsers:
- ids := make([]ent.Value, 0, len(m.allowed_users))
- for id := range m.allowed_users {
- ids = append(ids, id)
- }
- return ids
- }
- return nil
+// ClearModelRouting clears the value of the "model_routing" field.
+func (m *GroupMutation) ClearModelRouting() {
+ m.model_routing = nil
+ m.clearedFields[group.FieldModelRouting] = struct{}{}
}
-// RemovedEdges returns all edge names that were removed in this mutation.
-func (m *GroupMutation) RemovedEdges() []string {
- edges := make([]string, 0, 6)
- if m.removedapi_keys != nil {
- edges = append(edges, group.EdgeAPIKeys)
+// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation.
+func (m *GroupMutation) ModelRoutingCleared() bool {
+ _, ok := m.clearedFields[group.FieldModelRouting]
+ return ok
+}
+
+// ResetModelRouting resets all changes to the "model_routing" field.
+func (m *GroupMutation) ResetModelRouting() {
+ m.model_routing = nil
+ delete(m.clearedFields, group.FieldModelRouting)
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (m *GroupMutation) SetModelRoutingEnabled(b bool) {
+ m.model_routing_enabled = &b
+}
+
+// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation.
+func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) {
+ v := m.model_routing_enabled
+ if v == nil {
+ return
}
- if m.removedredeem_codes != nil {
- edges = append(edges, group.EdgeRedeemCodes)
+ return *v, true
+}
+
+// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations")
}
- if m.removedsubscriptions != nil {
- edges = append(edges, group.EdgeSubscriptions)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation")
}
- if m.removedusage_logs != nil {
- edges = append(edges, group.EdgeUsageLogs)
- }
- if m.removedaccounts != nil {
- edges = append(edges, group.EdgeAccounts)
- }
- if m.removedallowed_users != nil {
- edges = append(edges, group.EdgeAllowedUsers)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err)
}
- return edges
+ return oldValue.ModelRoutingEnabled, nil
}
-// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
-// the given name in this mutation.
-func (m *GroupMutation) RemovedIDs(name string) []ent.Value {
- switch name {
- case group.EdgeAPIKeys:
- ids := make([]ent.Value, 0, len(m.removedapi_keys))
- for id := range m.removedapi_keys {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeRedeemCodes:
- ids := make([]ent.Value, 0, len(m.removedredeem_codes))
- for id := range m.removedredeem_codes {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeSubscriptions:
- ids := make([]ent.Value, 0, len(m.removedsubscriptions))
- for id := range m.removedsubscriptions {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeUsageLogs:
- ids := make([]ent.Value, 0, len(m.removedusage_logs))
- for id := range m.removedusage_logs {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeAccounts:
- ids := make([]ent.Value, 0, len(m.removedaccounts))
- for id := range m.removedaccounts {
- ids = append(ids, id)
- }
- return ids
- case group.EdgeAllowedUsers:
- ids := make([]ent.Value, 0, len(m.removedallowed_users))
- for id := range m.removedallowed_users {
- ids = append(ids, id)
- }
- return ids
- }
- return nil
+// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field.
+func (m *GroupMutation) ResetModelRoutingEnabled() {
+ m.model_routing_enabled = nil
}
-// ClearedEdges returns all edge names that were cleared in this mutation.
-func (m *GroupMutation) ClearedEdges() []string {
- edges := make([]string, 0, 6)
- if m.clearedapi_keys {
- edges = append(edges, group.EdgeAPIKeys)
- }
- if m.clearedredeem_codes {
- edges = append(edges, group.EdgeRedeemCodes)
- }
- if m.clearedsubscriptions {
- edges = append(edges, group.EdgeSubscriptions)
+// SetMcpXMLInject sets the "mcp_xml_inject" field.
+func (m *GroupMutation) SetMcpXMLInject(b bool) {
+ m.mcp_xml_inject = &b
+}
+
+// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation.
+func (m *GroupMutation) McpXMLInject() (r bool, exists bool) {
+ v := m.mcp_xml_inject
+ if v == nil {
+ return
}
- if m.clearedusage_logs {
- edges = append(edges, group.EdgeUsageLogs)
+ return *v, true
+}
+
+// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations")
}
- if m.clearedaccounts {
- edges = append(edges, group.EdgeAccounts)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMcpXMLInject requires an ID field in the mutation")
}
- if m.clearedallowed_users {
- edges = append(edges, group.EdgeAllowedUsers)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err)
}
- return edges
+ return oldValue.McpXMLInject, nil
}
-// EdgeCleared returns a boolean which indicates if the edge with the given name
-// was cleared in this mutation.
-func (m *GroupMutation) EdgeCleared(name string) bool {
- switch name {
- case group.EdgeAPIKeys:
- return m.clearedapi_keys
- case group.EdgeRedeemCodes:
- return m.clearedredeem_codes
- case group.EdgeSubscriptions:
- return m.clearedsubscriptions
- case group.EdgeUsageLogs:
- return m.clearedusage_logs
- case group.EdgeAccounts:
- return m.clearedaccounts
- case group.EdgeAllowedUsers:
- return m.clearedallowed_users
- }
- return false
+// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field.
+func (m *GroupMutation) ResetMcpXMLInject() {
+ m.mcp_xml_inject = nil
}
-// ClearEdge clears the value of the edge with the given name. It returns an error
-// if that edge is not defined in the schema.
-func (m *GroupMutation) ClearEdge(name string) error {
- switch name {
+// SetSupportedModelScopes sets the "supported_model_scopes" field.
+func (m *GroupMutation) SetSupportedModelScopes(s []string) {
+ m.supported_model_scopes = &s
+ m.appendsupported_model_scopes = nil
+}
+
+// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation.
+func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) {
+ v := m.supported_model_scopes
+ if v == nil {
+ return
}
- return fmt.Errorf("unknown Group unique edge %s", name)
+ return *v, true
}
-// ResetEdge resets all changes to the edge with the given name in this mutation.
-// It returns an error if the edge is not defined in the schema.
-func (m *GroupMutation) ResetEdge(name string) error {
- switch name {
- case group.EdgeAPIKeys:
- m.ResetAPIKeys()
- return nil
- case group.EdgeRedeemCodes:
- m.ResetRedeemCodes()
- return nil
- case group.EdgeSubscriptions:
- m.ResetSubscriptions()
- return nil
- case group.EdgeUsageLogs:
- m.ResetUsageLogs()
- return nil
- case group.EdgeAccounts:
- m.ResetAccounts()
- return nil
- case group.EdgeAllowedUsers:
- m.ResetAllowedUsers()
- return nil
+// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations")
}
- return fmt.Errorf("unknown Group edge %s", name)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err)
+ }
+ return oldValue.SupportedModelScopes, nil
}
-// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph.
-type IdempotencyRecordMutation struct {
- config
- op Op
- typ string
- id *int64
- created_at *time.Time
- updated_at *time.Time
- scope *string
- idempotency_key_hash *string
- request_fingerprint *string
- status *string
- response_status *int
- addresponse_status *int
- response_body *string
- error_reason *string
- locked_until *time.Time
- expires_at *time.Time
- clearedFields map[string]struct{}
- done bool
- oldValue func(context.Context) (*IdempotencyRecord, error)
- predicates []predicate.IdempotencyRecord
+// AppendSupportedModelScopes adds s to the "supported_model_scopes" field.
+func (m *GroupMutation) AppendSupportedModelScopes(s []string) {
+ m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...)
}
-var _ ent.Mutation = (*IdempotencyRecordMutation)(nil)
+// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation.
+func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) {
+ if len(m.appendsupported_model_scopes) == 0 {
+ return nil, false
+ }
+ return m.appendsupported_model_scopes, true
+}
-// idempotencyrecordOption allows management of the mutation configuration using functional options.
-type idempotencyrecordOption func(*IdempotencyRecordMutation)
+// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field.
+func (m *GroupMutation) ResetSupportedModelScopes() {
+ m.supported_model_scopes = nil
+ m.appendsupported_model_scopes = nil
+}
-// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity.
-func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation {
- m := &IdempotencyRecordMutation{
- config: c,
- op: op,
- typ: TypeIdempotencyRecord,
- clearedFields: make(map[string]struct{}),
- }
- for _, opt := range opts {
- opt(m)
- }
- return m
+// SetSortOrder sets the "sort_order" field.
+func (m *GroupMutation) SetSortOrder(i int) {
+ m.sort_order = &i
+ m.addsort_order = nil
}
-// withIdempotencyRecordID sets the ID field of the mutation.
-func withIdempotencyRecordID(id int64) idempotencyrecordOption {
- return func(m *IdempotencyRecordMutation) {
- var (
- err error
- once sync.Once
- value *IdempotencyRecord
- )
- m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) {
- once.Do(func() {
- if m.done {
- err = errors.New("querying old values post mutation is not allowed")
- } else {
- value, err = m.Client().IdempotencyRecord.Get(ctx, id)
- }
- })
- return value, err
- }
- m.id = &id
+// SortOrder returns the value of the "sort_order" field in the mutation.
+func (m *GroupMutation) SortOrder() (r int, exists bool) {
+ v := m.sort_order
+ if v == nil {
+ return
}
+ return *v, true
}
-// withIdempotencyRecord sets the old IdempotencyRecord of the mutation.
-func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption {
- return func(m *IdempotencyRecordMutation) {
- m.oldValue = func(context.Context) (*IdempotencyRecord, error) {
- return node, nil
- }
- m.id = &node.ID
+// OldSortOrder returns the old "sort_order" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
}
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSortOrder requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
+ }
+ return oldValue.SortOrder, nil
}
-// Client returns a new `ent.Client` from the mutation. If the mutation was
-// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m IdempotencyRecordMutation) Client() *Client {
- client := &Client{config: m.config}
- client.init()
- return client
-}
-
-// Tx returns an `ent.Tx` for mutations that were executed in transactions;
-// it returns an error otherwise.
-func (m IdempotencyRecordMutation) Tx() (*Tx, error) {
- if _, ok := m.driver.(*txDriver); !ok {
- return nil, errors.New("ent: mutation is not running in a transaction")
+// AddSortOrder adds i to the "sort_order" field.
+func (m *GroupMutation) AddSortOrder(i int) {
+ if m.addsort_order != nil {
+ *m.addsort_order += i
+ } else {
+ m.addsort_order = &i
}
- tx := &Tx{config: m.config}
- tx.init()
- return tx, nil
}
-// ID returns the ID value in the mutation. Note that the ID is only available
-// if it was provided to the builder or after it was returned from the database.
-func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) {
- if m.id == nil {
+// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
+func (m *GroupMutation) AddedSortOrder() (r int, exists bool) {
+ v := m.addsort_order
+ if v == nil {
return
}
- return *m.id, true
+ return *v, true
}
-// IDs queries the database and returns the entity ids that match the mutation's predicate.
-// That means, if the mutation is applied within a transaction with an isolation level such
-// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
-// or updated by the mutation.
-func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) {
- switch {
- case m.op.Is(OpUpdateOne | OpDeleteOne):
- id, exists := m.ID()
- if exists {
- return []int64{id}, nil
- }
- fallthrough
- case m.op.Is(OpUpdate | OpDelete):
- return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx)
- default:
- return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
- }
+// ResetSortOrder resets all changes to the "sort_order" field.
+func (m *GroupMutation) ResetSortOrder() {
+ m.sort_order = nil
+ m.addsort_order = nil
}
-// SetCreatedAt sets the "created_at" field.
-func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) {
- m.created_at = &t
+// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
+func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
+ m.allow_messages_dispatch = &b
}
-// CreatedAt returns the value of the "created_at" field in the mutation.
-func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) {
- v := m.created_at
+// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
+func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
+ v := m.allow_messages_dispatch
if v == nil {
return
}
return *v, true
}
-// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
}
- return oldValue.CreatedAt, nil
+ return oldValue.AllowMessagesDispatch, nil
}
-// ResetCreatedAt resets all changes to the "created_at" field.
-func (m *IdempotencyRecordMutation) ResetCreatedAt() {
- m.created_at = nil
+// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
+func (m *GroupMutation) ResetAllowMessagesDispatch() {
+ m.allow_messages_dispatch = nil
}
-// SetUpdatedAt sets the "updated_at" field.
-func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) {
- m.updated_at = &t
+// SetRequireOauthOnly sets the "require_oauth_only" field.
+func (m *GroupMutation) SetRequireOauthOnly(b bool) {
+ m.require_oauth_only = &b
}
-// UpdatedAt returns the value of the "updated_at" field in the mutation.
-func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) {
- v := m.updated_at
+// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation.
+func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) {
+ v := m.require_oauth_only
if v == nil {
return
}
return *v, true
}
-// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err)
}
- return oldValue.UpdatedAt, nil
+ return oldValue.RequireOauthOnly, nil
}
-// ResetUpdatedAt resets all changes to the "updated_at" field.
-func (m *IdempotencyRecordMutation) ResetUpdatedAt() {
- m.updated_at = nil
+// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field.
+func (m *GroupMutation) ResetRequireOauthOnly() {
+ m.require_oauth_only = nil
}
-// SetScope sets the "scope" field.
-func (m *IdempotencyRecordMutation) SetScope(s string) {
- m.scope = &s
+// SetRequirePrivacySet sets the "require_privacy_set" field.
+func (m *GroupMutation) SetRequirePrivacySet(b bool) {
+ m.require_privacy_set = &b
}
-// Scope returns the value of the "scope" field in the mutation.
-func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) {
- v := m.scope
+// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation.
+func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) {
+ v := m.require_privacy_set
if v == nil {
return
}
return *v, true
}
-// OldScope returns the old "scope" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) {
+func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldScope is only allowed on UpdateOne operations")
+ return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldScope requires an ID field in the mutation")
+ return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldScope: %w", err)
+ return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err)
}
- return oldValue.Scope, nil
+ return oldValue.RequirePrivacySet, nil
}
-// ResetScope resets all changes to the "scope" field.
-func (m *IdempotencyRecordMutation) ResetScope() {
- m.scope = nil
+// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field.
+func (m *GroupMutation) ResetRequirePrivacySet() {
+ m.require_privacy_set = nil
}
-// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
-func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) {
- m.idempotency_key_hash = &s
+// SetDefaultMappedModel sets the "default_mapped_model" field.
+func (m *GroupMutation) SetDefaultMappedModel(s string) {
+ m.default_mapped_model = &s
}
-// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation.
-func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) {
- v := m.idempotency_key_hash
+// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
+func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
+ v := m.default_mapped_model
if v == nil {
return
}
return *v, true
}
-// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) {
+func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations")
+ return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation")
+ return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err)
+ return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
}
- return oldValue.IdempotencyKeyHash, nil
+ return oldValue.DefaultMappedModel, nil
}
-// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field.
-func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() {
- m.idempotency_key_hash = nil
+// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
+func (m *GroupMutation) ResetDefaultMappedModel() {
+ m.default_mapped_model = nil
}
-// SetRequestFingerprint sets the "request_fingerprint" field.
-func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) {
- m.request_fingerprint = &s
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) {
+ m.messages_dispatch_model_config = &damdmc
}
-// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation.
-func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) {
- v := m.request_fingerprint
+// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation.
+func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) {
+ v := m.messages_dispatch_model_config
if v == nil {
return
}
return *v, true
}
-// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) {
+func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations")
+ return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRequestFingerprint requires an ID field in the mutation")
+ return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err)
+ return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err)
}
- return oldValue.RequestFingerprint, nil
+ return oldValue.MessagesDispatchModelConfig, nil
}
-// ResetRequestFingerprint resets all changes to the "request_fingerprint" field.
-func (m *IdempotencyRecordMutation) ResetRequestFingerprint() {
- m.request_fingerprint = nil
+// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field.
+func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
+ m.messages_dispatch_model_config = nil
}
-// SetStatus sets the "status" field.
-func (m *IdempotencyRecordMutation) SetStatus(s string) {
- m.status = &s
+// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
+func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
+ if m.api_keys == nil {
+ m.api_keys = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.api_keys[ids[i]] = struct{}{}
+ }
}
-// Status returns the value of the "status" field in the mutation.
-func (m *IdempotencyRecordMutation) Status() (r string, exists bool) {
- v := m.status
- if v == nil {
- return
- }
- return *v, true
+// ClearAPIKeys clears the "api_keys" edge to the APIKey entity.
+func (m *GroupMutation) ClearAPIKeys() {
+ m.clearedapi_keys = true
}
-// OldStatus returns the old "status" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared.
+func (m *GroupMutation) APIKeysCleared() bool {
+ return m.clearedapi_keys
+}
+
+// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs.
+func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) {
+ if m.removedapi_keys == nil {
+ m.removedapi_keys = make(map[int64]struct{})
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldStatus requires an ID field in the mutation")
+ for i := range ids {
+ delete(m.api_keys, ids[i])
+ m.removedapi_keys[ids[i]] = struct{}{}
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+}
+
+// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity.
+func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) {
+ for id := range m.removedapi_keys {
+ ids = append(ids, id)
}
- return oldValue.Status, nil
+ return
}
-// ResetStatus resets all changes to the "status" field.
-func (m *IdempotencyRecordMutation) ResetStatus() {
- m.status = nil
+// APIKeysIDs returns the "api_keys" edge IDs in the mutation.
+func (m *GroupMutation) APIKeysIDs() (ids []int64) {
+ for id := range m.api_keys {
+ ids = append(ids, id)
+ }
+ return
}
-// SetResponseStatus sets the "response_status" field.
-func (m *IdempotencyRecordMutation) SetResponseStatus(i int) {
- m.response_status = &i
- m.addresponse_status = nil
+// ResetAPIKeys resets all changes to the "api_keys" edge.
+func (m *GroupMutation) ResetAPIKeys() {
+ m.api_keys = nil
+ m.clearedapi_keys = false
+ m.removedapi_keys = nil
}
-// ResponseStatus returns the value of the "response_status" field in the mutation.
-func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) {
- v := m.response_status
- if v == nil {
- return
+// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids.
+func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) {
+ if m.redeem_codes == nil {
+ m.redeem_codes = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.redeem_codes[ids[i]] = struct{}{}
}
- return *v, true
}
-// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldResponseStatus requires an ID field in the mutation")
+// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity.
+func (m *GroupMutation) ClearRedeemCodes() {
+ m.clearedredeem_codes = true
+}
+
+// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared.
+func (m *GroupMutation) RedeemCodesCleared() bool {
+ return m.clearedredeem_codes
+}
+
+// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs.
+func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) {
+ if m.removedredeem_codes == nil {
+ m.removedredeem_codes = make(map[int64]struct{})
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err)
+ for i := range ids {
+ delete(m.redeem_codes, ids[i])
+ m.removedredeem_codes[ids[i]] = struct{}{}
}
- return oldValue.ResponseStatus, nil
}
-// AddResponseStatus adds i to the "response_status" field.
-func (m *IdempotencyRecordMutation) AddResponseStatus(i int) {
- if m.addresponse_status != nil {
- *m.addresponse_status += i
- } else {
- m.addresponse_status = &i
+// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity.
+func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) {
+ for id := range m.removedredeem_codes {
+ ids = append(ids, id)
}
+ return
}
-// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation.
-func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) {
- v := m.addresponse_status
- if v == nil {
- return
+// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation.
+func (m *GroupMutation) RedeemCodesIDs() (ids []int64) {
+ for id := range m.redeem_codes {
+ ids = append(ids, id)
}
- return *v, true
+ return
}
-// ClearResponseStatus clears the value of the "response_status" field.
-func (m *IdempotencyRecordMutation) ClearResponseStatus() {
- m.response_status = nil
- m.addresponse_status = nil
- m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{}
+// ResetRedeemCodes resets all changes to the "redeem_codes" edge.
+func (m *GroupMutation) ResetRedeemCodes() {
+ m.redeem_codes = nil
+ m.clearedredeem_codes = false
+ m.removedredeem_codes = nil
}
-// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation.
-func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool {
- _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus]
- return ok
+// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids.
+func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) {
+ if m.subscriptions == nil {
+ m.subscriptions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.subscriptions[ids[i]] = struct{}{}
+ }
}
-// ResetResponseStatus resets all changes to the "response_status" field.
-func (m *IdempotencyRecordMutation) ResetResponseStatus() {
- m.response_status = nil
- m.addresponse_status = nil
- delete(m.clearedFields, idempotencyrecord.FieldResponseStatus)
+// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity.
+func (m *GroupMutation) ClearSubscriptions() {
+ m.clearedsubscriptions = true
}
-// SetResponseBody sets the "response_body" field.
-func (m *IdempotencyRecordMutation) SetResponseBody(s string) {
- m.response_body = &s
+// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared.
+func (m *GroupMutation) SubscriptionsCleared() bool {
+ return m.clearedsubscriptions
}
-// ResponseBody returns the value of the "response_body" field in the mutation.
-func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) {
- v := m.response_body
- if v == nil {
- return
+// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs.
+func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) {
+ if m.removedsubscriptions == nil {
+ m.removedsubscriptions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.subscriptions, ids[i])
+ m.removedsubscriptions[ids[i]] = struct{}{}
}
- return *v, true
}
-// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldResponseBody is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldResponseBody requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldResponseBody: %w", err)
+// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity.
+func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) {
+ for id := range m.removedsubscriptions {
+ ids = append(ids, id)
}
- return oldValue.ResponseBody, nil
+ return
}
-// ClearResponseBody clears the value of the "response_body" field.
-func (m *IdempotencyRecordMutation) ClearResponseBody() {
- m.response_body = nil
- m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{}
+// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation.
+func (m *GroupMutation) SubscriptionsIDs() (ids []int64) {
+ for id := range m.subscriptions {
+ ids = append(ids, id)
+ }
+ return
}
-// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation.
-func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool {
- _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody]
- return ok
+// ResetSubscriptions resets all changes to the "subscriptions" edge.
+func (m *GroupMutation) ResetSubscriptions() {
+ m.subscriptions = nil
+ m.clearedsubscriptions = false
+ m.removedsubscriptions = nil
}
-// ResetResponseBody resets all changes to the "response_body" field.
-func (m *IdempotencyRecordMutation) ResetResponseBody() {
- m.response_body = nil
- delete(m.clearedFields, idempotencyrecord.FieldResponseBody)
+// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids.
+func (m *GroupMutation) AddUsageLogIDs(ids ...int64) {
+ if m.usage_logs == nil {
+ m.usage_logs = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.usage_logs[ids[i]] = struct{}{}
+ }
}
-// SetErrorReason sets the "error_reason" field.
-func (m *IdempotencyRecordMutation) SetErrorReason(s string) {
- m.error_reason = &s
+// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity.
+func (m *GroupMutation) ClearUsageLogs() {
+ m.clearedusage_logs = true
}
-// ErrorReason returns the value of the "error_reason" field in the mutation.
-func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) {
- v := m.error_reason
- if v == nil {
- return
- }
- return *v, true
+// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared.
+func (m *GroupMutation) UsageLogsCleared() bool {
+ return m.clearedusage_logs
}
-// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldErrorReason is only allowed on UpdateOne operations")
+// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs.
+func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) {
+ if m.removedusage_logs == nil {
+ m.removedusage_logs = make(map[int64]struct{})
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldErrorReason requires an ID field in the mutation")
+ for i := range ids {
+ delete(m.usage_logs, ids[i])
+ m.removedusage_logs[ids[i]] = struct{}{}
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldErrorReason: %w", err)
+}
+
+// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity.
+func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) {
+ for id := range m.removedusage_logs {
+ ids = append(ids, id)
}
- return oldValue.ErrorReason, nil
+ return
}
-// ClearErrorReason clears the value of the "error_reason" field.
-func (m *IdempotencyRecordMutation) ClearErrorReason() {
- m.error_reason = nil
- m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{}
+// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation.
+func (m *GroupMutation) UsageLogsIDs() (ids []int64) {
+ for id := range m.usage_logs {
+ ids = append(ids, id)
+ }
+ return
}
-// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation.
-func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool {
- _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason]
- return ok
+// ResetUsageLogs resets all changes to the "usage_logs" edge.
+func (m *GroupMutation) ResetUsageLogs() {
+ m.usage_logs = nil
+ m.clearedusage_logs = false
+ m.removedusage_logs = nil
}
-// ResetErrorReason resets all changes to the "error_reason" field.
-func (m *IdempotencyRecordMutation) ResetErrorReason() {
- m.error_reason = nil
- delete(m.clearedFields, idempotencyrecord.FieldErrorReason)
+// AddAccountIDs adds the "accounts" edge to the Account entity by ids.
+func (m *GroupMutation) AddAccountIDs(ids ...int64) {
+ if m.accounts == nil {
+ m.accounts = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.accounts[ids[i]] = struct{}{}
+ }
}
-// SetLockedUntil sets the "locked_until" field.
-func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) {
- m.locked_until = &t
+// ClearAccounts clears the "accounts" edge to the Account entity.
+func (m *GroupMutation) ClearAccounts() {
+ m.clearedaccounts = true
}
-// LockedUntil returns the value of the "locked_until" field in the mutation.
-func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) {
- v := m.locked_until
- if v == nil {
- return
- }
- return *v, true
+// AccountsCleared reports if the "accounts" edge to the Account entity was cleared.
+func (m *GroupMutation) AccountsCleared() bool {
+ return m.clearedaccounts
}
-// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations")
+// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs.
+func (m *GroupMutation) RemoveAccountIDs(ids ...int64) {
+ if m.removedaccounts == nil {
+ m.removedaccounts = make(map[int64]struct{})
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldLockedUntil requires an ID field in the mutation")
+ for i := range ids {
+ delete(m.accounts, ids[i])
+ m.removedaccounts[ids[i]] = struct{}{}
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err)
+}
+
+// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity.
+func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) {
+ for id := range m.removedaccounts {
+ ids = append(ids, id)
}
- return oldValue.LockedUntil, nil
+ return
}
-// ClearLockedUntil clears the value of the "locked_until" field.
-func (m *IdempotencyRecordMutation) ClearLockedUntil() {
- m.locked_until = nil
- m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{}
+// AccountsIDs returns the "accounts" edge IDs in the mutation.
+func (m *GroupMutation) AccountsIDs() (ids []int64) {
+ for id := range m.accounts {
+ ids = append(ids, id)
+ }
+ return
}
-// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation.
-func (m *IdempotencyRecordMutation) LockedUntilCleared() bool {
- _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil]
- return ok
+// ResetAccounts resets all changes to the "accounts" edge.
+func (m *GroupMutation) ResetAccounts() {
+ m.accounts = nil
+ m.clearedaccounts = false
+ m.removedaccounts = nil
}
-// ResetLockedUntil resets all changes to the "locked_until" field.
-func (m *IdempotencyRecordMutation) ResetLockedUntil() {
- m.locked_until = nil
- delete(m.clearedFields, idempotencyrecord.FieldLockedUntil)
+// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids.
+func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) {
+ if m.allowed_users == nil {
+ m.allowed_users = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.allowed_users[ids[i]] = struct{}{}
+ }
}
-// SetExpiresAt sets the "expires_at" field.
-func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) {
- m.expires_at = &t
+// ClearAllowedUsers clears the "allowed_users" edge to the User entity.
+func (m *GroupMutation) ClearAllowedUsers() {
+ m.clearedallowed_users = true
}
-// ExpiresAt returns the value of the "expires_at" field in the mutation.
-func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) {
- v := m.expires_at
- if v == nil {
- return
- }
- return *v, true
+// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared.
+func (m *GroupMutation) AllowedUsersCleared() bool {
+ return m.clearedallowed_users
}
-// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity.
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs.
+func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) {
+ if m.removedallowed_users == nil {
+ m.removedallowed_users = make(map[int64]struct{})
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ for i := range ids {
+ delete(m.allowed_users, ids[i])
+ m.removedallowed_users[ids[i]] = struct{}{}
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+}
+
+// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity.
+func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) {
+ for id := range m.removedallowed_users {
+ ids = append(ids, id)
}
- return oldValue.ExpiresAt, nil
+ return
}
-// ResetExpiresAt resets all changes to the "expires_at" field.
-func (m *IdempotencyRecordMutation) ResetExpiresAt() {
- m.expires_at = nil
+// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation.
+func (m *GroupMutation) AllowedUsersIDs() (ids []int64) {
+ for id := range m.allowed_users {
+ ids = append(ids, id)
+ }
+ return
}
-// Where appends a list predicates to the IdempotencyRecordMutation builder.
-func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) {
+// ResetAllowedUsers resets all changes to the "allowed_users" edge.
+func (m *GroupMutation) ResetAllowedUsers() {
+ m.allowed_users = nil
+ m.clearedallowed_users = false
+ m.removedallowed_users = nil
+}
+
+// Where appends a list predicates to the GroupMutation builder.
+func (m *GroupMutation) Where(ps ...predicate.Group) {
m.predicates = append(m.predicates, ps...)
}
-// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method,
+// WhereP appends storage-level predicates to the GroupMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
-func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) {
- p := make([]predicate.IdempotencyRecord, len(ps))
+func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.Group, len(ps))
for i := range ps {
p[i] = ps[i]
}
@@ -11816,215 +12030,511 @@ func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) {
}
// Op returns the operation name.
-func (m *IdempotencyRecordMutation) Op() Op {
+func (m *GroupMutation) Op() Op {
return m.op
}
// SetOp allows setting the mutation operation.
-func (m *IdempotencyRecordMutation) SetOp(op Op) {
+func (m *GroupMutation) SetOp(op Op) {
m.op = op
}
-// Type returns the node type of this mutation (IdempotencyRecord).
-func (m *IdempotencyRecordMutation) Type() string {
+// Type returns the node type of this mutation (Group).
+func (m *GroupMutation) Type() string {
return m.typ
}
// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
-func (m *IdempotencyRecordMutation) Fields() []string {
- fields := make([]string, 0, 11)
+func (m *GroupMutation) Fields() []string {
+ fields := make([]string, 0, 30)
if m.created_at != nil {
- fields = append(fields, idempotencyrecord.FieldCreatedAt)
+ fields = append(fields, group.FieldCreatedAt)
}
if m.updated_at != nil {
- fields = append(fields, idempotencyrecord.FieldUpdatedAt)
+ fields = append(fields, group.FieldUpdatedAt)
}
- if m.scope != nil {
- fields = append(fields, idempotencyrecord.FieldScope)
+ if m.deleted_at != nil {
+ fields = append(fields, group.FieldDeletedAt)
}
- if m.idempotency_key_hash != nil {
- fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash)
+ if m.name != nil {
+ fields = append(fields, group.FieldName)
}
- if m.request_fingerprint != nil {
- fields = append(fields, idempotencyrecord.FieldRequestFingerprint)
+ if m.description != nil {
+ fields = append(fields, group.FieldDescription)
}
- if m.status != nil {
- fields = append(fields, idempotencyrecord.FieldStatus)
+ if m.rate_multiplier != nil {
+ fields = append(fields, group.FieldRateMultiplier)
}
- if m.response_status != nil {
- fields = append(fields, idempotencyrecord.FieldResponseStatus)
+ if m.is_exclusive != nil {
+ fields = append(fields, group.FieldIsExclusive)
}
- if m.response_body != nil {
- fields = append(fields, idempotencyrecord.FieldResponseBody)
+ if m.status != nil {
+ fields = append(fields, group.FieldStatus)
}
- if m.error_reason != nil {
- fields = append(fields, idempotencyrecord.FieldErrorReason)
+ if m.platform != nil {
+ fields = append(fields, group.FieldPlatform)
}
- if m.locked_until != nil {
- fields = append(fields, idempotencyrecord.FieldLockedUntil)
+ if m.subscription_type != nil {
+ fields = append(fields, group.FieldSubscriptionType)
}
- if m.expires_at != nil {
- fields = append(fields, idempotencyrecord.FieldExpiresAt)
+ if m.daily_limit_usd != nil {
+ fields = append(fields, group.FieldDailyLimitUsd)
}
- return fields
-}
-
-// Field returns the value of a field with the given name. The second boolean
-// return value indicates that this field was not set, or was not defined in the
-// schema.
-func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) {
- switch name {
- case idempotencyrecord.FieldCreatedAt:
- return m.CreatedAt()
- case idempotencyrecord.FieldUpdatedAt:
- return m.UpdatedAt()
- case idempotencyrecord.FieldScope:
- return m.Scope()
- case idempotencyrecord.FieldIdempotencyKeyHash:
- return m.IdempotencyKeyHash()
- case idempotencyrecord.FieldRequestFingerprint:
- return m.RequestFingerprint()
- case idempotencyrecord.FieldStatus:
- return m.Status()
- case idempotencyrecord.FieldResponseStatus:
- return m.ResponseStatus()
- case idempotencyrecord.FieldResponseBody:
- return m.ResponseBody()
- case idempotencyrecord.FieldErrorReason:
- return m.ErrorReason()
- case idempotencyrecord.FieldLockedUntil:
- return m.LockedUntil()
- case idempotencyrecord.FieldExpiresAt:
- return m.ExpiresAt()
+ if m.weekly_limit_usd != nil {
+ fields = append(fields, group.FieldWeeklyLimitUsd)
}
- return nil, false
+ if m.monthly_limit_usd != nil {
+ fields = append(fields, group.FieldMonthlyLimitUsd)
+ }
+ if m.default_validity_days != nil {
+ fields = append(fields, group.FieldDefaultValidityDays)
+ }
+ if m.image_price_1k != nil {
+ fields = append(fields, group.FieldImagePrice1k)
+ }
+ if m.image_price_2k != nil {
+ fields = append(fields, group.FieldImagePrice2k)
+ }
+ if m.image_price_4k != nil {
+ fields = append(fields, group.FieldImagePrice4k)
+ }
+ if m.claude_code_only != nil {
+ fields = append(fields, group.FieldClaudeCodeOnly)
+ }
+ if m.fallback_group_id != nil {
+ fields = append(fields, group.FieldFallbackGroupID)
+ }
+ if m.fallback_group_id_on_invalid_request != nil {
+ fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
+ }
+ if m.model_routing != nil {
+ fields = append(fields, group.FieldModelRouting)
+ }
+ if m.model_routing_enabled != nil {
+ fields = append(fields, group.FieldModelRoutingEnabled)
+ }
+ if m.mcp_xml_inject != nil {
+ fields = append(fields, group.FieldMcpXMLInject)
+ }
+ if m.supported_model_scopes != nil {
+ fields = append(fields, group.FieldSupportedModelScopes)
+ }
+ if m.sort_order != nil {
+ fields = append(fields, group.FieldSortOrder)
+ }
+ if m.allow_messages_dispatch != nil {
+ fields = append(fields, group.FieldAllowMessagesDispatch)
+ }
+ if m.require_oauth_only != nil {
+ fields = append(fields, group.FieldRequireOauthOnly)
+ }
+ if m.require_privacy_set != nil {
+ fields = append(fields, group.FieldRequirePrivacySet)
+ }
+ if m.default_mapped_model != nil {
+ fields = append(fields, group.FieldDefaultMappedModel)
+ }
+ if m.messages_dispatch_model_config != nil {
+ fields = append(fields, group.FieldMessagesDispatchModelConfig)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *GroupMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case group.FieldCreatedAt:
+ return m.CreatedAt()
+ case group.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case group.FieldDeletedAt:
+ return m.DeletedAt()
+ case group.FieldName:
+ return m.Name()
+ case group.FieldDescription:
+ return m.Description()
+ case group.FieldRateMultiplier:
+ return m.RateMultiplier()
+ case group.FieldIsExclusive:
+ return m.IsExclusive()
+ case group.FieldStatus:
+ return m.Status()
+ case group.FieldPlatform:
+ return m.Platform()
+ case group.FieldSubscriptionType:
+ return m.SubscriptionType()
+ case group.FieldDailyLimitUsd:
+ return m.DailyLimitUsd()
+ case group.FieldWeeklyLimitUsd:
+ return m.WeeklyLimitUsd()
+ case group.FieldMonthlyLimitUsd:
+ return m.MonthlyLimitUsd()
+ case group.FieldDefaultValidityDays:
+ return m.DefaultValidityDays()
+ case group.FieldImagePrice1k:
+ return m.ImagePrice1k()
+ case group.FieldImagePrice2k:
+ return m.ImagePrice2k()
+ case group.FieldImagePrice4k:
+ return m.ImagePrice4k()
+ case group.FieldClaudeCodeOnly:
+ return m.ClaudeCodeOnly()
+ case group.FieldFallbackGroupID:
+ return m.FallbackGroupID()
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ return m.FallbackGroupIDOnInvalidRequest()
+ case group.FieldModelRouting:
+ return m.ModelRouting()
+ case group.FieldModelRoutingEnabled:
+ return m.ModelRoutingEnabled()
+ case group.FieldMcpXMLInject:
+ return m.McpXMLInject()
+ case group.FieldSupportedModelScopes:
+ return m.SupportedModelScopes()
+ case group.FieldSortOrder:
+ return m.SortOrder()
+ case group.FieldAllowMessagesDispatch:
+ return m.AllowMessagesDispatch()
+ case group.FieldRequireOauthOnly:
+ return m.RequireOauthOnly()
+ case group.FieldRequirePrivacySet:
+ return m.RequirePrivacySet()
+ case group.FieldDefaultMappedModel:
+ return m.DefaultMappedModel()
+ case group.FieldMessagesDispatchModelConfig:
+ return m.MessagesDispatchModelConfig()
+ }
+ return nil, false
}
// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
-func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
- case idempotencyrecord.FieldCreatedAt:
+ case group.FieldCreatedAt:
return m.OldCreatedAt(ctx)
- case idempotencyrecord.FieldUpdatedAt:
+ case group.FieldUpdatedAt:
return m.OldUpdatedAt(ctx)
- case idempotencyrecord.FieldScope:
- return m.OldScope(ctx)
- case idempotencyrecord.FieldIdempotencyKeyHash:
- return m.OldIdempotencyKeyHash(ctx)
- case idempotencyrecord.FieldRequestFingerprint:
- return m.OldRequestFingerprint(ctx)
- case idempotencyrecord.FieldStatus:
+ case group.FieldDeletedAt:
+ return m.OldDeletedAt(ctx)
+ case group.FieldName:
+ return m.OldName(ctx)
+ case group.FieldDescription:
+ return m.OldDescription(ctx)
+ case group.FieldRateMultiplier:
+ return m.OldRateMultiplier(ctx)
+ case group.FieldIsExclusive:
+ return m.OldIsExclusive(ctx)
+ case group.FieldStatus:
return m.OldStatus(ctx)
- case idempotencyrecord.FieldResponseStatus:
- return m.OldResponseStatus(ctx)
- case idempotencyrecord.FieldResponseBody:
- return m.OldResponseBody(ctx)
- case idempotencyrecord.FieldErrorReason:
- return m.OldErrorReason(ctx)
- case idempotencyrecord.FieldLockedUntil:
- return m.OldLockedUntil(ctx)
- case idempotencyrecord.FieldExpiresAt:
- return m.OldExpiresAt(ctx)
+ case group.FieldPlatform:
+ return m.OldPlatform(ctx)
+ case group.FieldSubscriptionType:
+ return m.OldSubscriptionType(ctx)
+ case group.FieldDailyLimitUsd:
+ return m.OldDailyLimitUsd(ctx)
+ case group.FieldWeeklyLimitUsd:
+ return m.OldWeeklyLimitUsd(ctx)
+ case group.FieldMonthlyLimitUsd:
+ return m.OldMonthlyLimitUsd(ctx)
+ case group.FieldDefaultValidityDays:
+ return m.OldDefaultValidityDays(ctx)
+ case group.FieldImagePrice1k:
+ return m.OldImagePrice1k(ctx)
+ case group.FieldImagePrice2k:
+ return m.OldImagePrice2k(ctx)
+ case group.FieldImagePrice4k:
+ return m.OldImagePrice4k(ctx)
+ case group.FieldClaudeCodeOnly:
+ return m.OldClaudeCodeOnly(ctx)
+ case group.FieldFallbackGroupID:
+ return m.OldFallbackGroupID(ctx)
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ return m.OldFallbackGroupIDOnInvalidRequest(ctx)
+ case group.FieldModelRouting:
+ return m.OldModelRouting(ctx)
+ case group.FieldModelRoutingEnabled:
+ return m.OldModelRoutingEnabled(ctx)
+ case group.FieldMcpXMLInject:
+ return m.OldMcpXMLInject(ctx)
+ case group.FieldSupportedModelScopes:
+ return m.OldSupportedModelScopes(ctx)
+ case group.FieldSortOrder:
+ return m.OldSortOrder(ctx)
+ case group.FieldAllowMessagesDispatch:
+ return m.OldAllowMessagesDispatch(ctx)
+ case group.FieldRequireOauthOnly:
+ return m.OldRequireOauthOnly(ctx)
+ case group.FieldRequirePrivacySet:
+ return m.OldRequirePrivacySet(ctx)
+ case group.FieldDefaultMappedModel:
+ return m.OldDefaultMappedModel(ctx)
+ case group.FieldMessagesDispatchModelConfig:
+ return m.OldMessagesDispatchModelConfig(ctx)
}
- return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name)
+ return nil, fmt.Errorf("unknown Group field %s", name)
}
// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
-func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error {
+func (m *GroupMutation) SetField(name string, value ent.Value) error {
switch name {
- case idempotencyrecord.FieldCreatedAt:
+ case group.FieldCreatedAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetCreatedAt(v)
return nil
- case idempotencyrecord.FieldUpdatedAt:
+ case group.FieldUpdatedAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUpdatedAt(v)
return nil
- case idempotencyrecord.FieldScope:
- v, ok := value.(string)
+ case group.FieldDeletedAt:
+ v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetScope(v)
+ m.SetDeletedAt(v)
return nil
- case idempotencyrecord.FieldIdempotencyKeyHash:
+ case group.FieldName:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetIdempotencyKeyHash(v)
+ m.SetName(v)
return nil
- case idempotencyrecord.FieldRequestFingerprint:
+ case group.FieldDescription:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetRequestFingerprint(v)
+ m.SetDescription(v)
return nil
- case idempotencyrecord.FieldStatus:
- v, ok := value.(string)
+ case group.FieldRateMultiplier:
+ v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetStatus(v)
+ m.SetRateMultiplier(v)
return nil
- case idempotencyrecord.FieldResponseStatus:
- v, ok := value.(int)
+ case group.FieldIsExclusive:
+ v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetResponseStatus(v)
+ m.SetIsExclusive(v)
return nil
- case idempotencyrecord.FieldResponseBody:
+ case group.FieldStatus:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetResponseBody(v)
+ m.SetStatus(v)
return nil
- case idempotencyrecord.FieldErrorReason:
+ case group.FieldPlatform:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetErrorReason(v)
+ m.SetPlatform(v)
return nil
- case idempotencyrecord.FieldLockedUntil:
- v, ok := value.(time.Time)
+ case group.FieldSubscriptionType:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetLockedUntil(v)
+ m.SetSubscriptionType(v)
return nil
- case idempotencyrecord.FieldExpiresAt:
- v, ok := value.(time.Time)
+ case group.FieldDailyLimitUsd:
+ v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetExpiresAt(v)
+ m.SetDailyLimitUsd(v)
+ return nil
+ case group.FieldWeeklyLimitUsd:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetWeeklyLimitUsd(v)
+ return nil
+ case group.FieldMonthlyLimitUsd:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMonthlyLimitUsd(v)
+ return nil
+ case group.FieldDefaultValidityDays:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDefaultValidityDays(v)
+ return nil
+ case group.FieldImagePrice1k:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImagePrice1k(v)
+ return nil
+ case group.FieldImagePrice2k:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImagePrice2k(v)
+ return nil
+ case group.FieldImagePrice4k:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImagePrice4k(v)
+ return nil
+ case group.FieldClaudeCodeOnly:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetClaudeCodeOnly(v)
+ return nil
+ case group.FieldFallbackGroupID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFallbackGroupID(v)
+ return nil
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFallbackGroupIDOnInvalidRequest(v)
+ return nil
+ case group.FieldModelRouting:
+ v, ok := value.(map[string][]int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModelRouting(v)
+ return nil
+ case group.FieldModelRoutingEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModelRoutingEnabled(v)
+ return nil
+ case group.FieldMcpXMLInject:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMcpXMLInject(v)
+ return nil
+ case group.FieldSupportedModelScopes:
+ v, ok := value.([]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSupportedModelScopes(v)
+ return nil
+ case group.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSortOrder(v)
+ return nil
+ case group.FieldAllowMessagesDispatch:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAllowMessagesDispatch(v)
+ return nil
+ case group.FieldRequireOauthOnly:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRequireOauthOnly(v)
+ return nil
+ case group.FieldRequirePrivacySet:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRequirePrivacySet(v)
+ return nil
+ case group.FieldDefaultMappedModel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDefaultMappedModel(v)
+ return nil
+ case group.FieldMessagesDispatchModelConfig:
+ v, ok := value.(domain.OpenAIMessagesDispatchModelConfig)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMessagesDispatchModelConfig(v)
return nil
}
- return fmt.Errorf("unknown IdempotencyRecord field %s", name)
+ return fmt.Errorf("unknown Group field %s", name)
}
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
-func (m *IdempotencyRecordMutation) AddedFields() []string {
+func (m *GroupMutation) AddedFields() []string {
var fields []string
- if m.addresponse_status != nil {
- fields = append(fields, idempotencyrecord.FieldResponseStatus)
+ if m.addrate_multiplier != nil {
+ fields = append(fields, group.FieldRateMultiplier)
+ }
+ if m.adddaily_limit_usd != nil {
+ fields = append(fields, group.FieldDailyLimitUsd)
+ }
+ if m.addweekly_limit_usd != nil {
+ fields = append(fields, group.FieldWeeklyLimitUsd)
+ }
+ if m.addmonthly_limit_usd != nil {
+ fields = append(fields, group.FieldMonthlyLimitUsd)
+ }
+ if m.adddefault_validity_days != nil {
+ fields = append(fields, group.FieldDefaultValidityDays)
+ }
+ if m.addimage_price_1k != nil {
+ fields = append(fields, group.FieldImagePrice1k)
+ }
+ if m.addimage_price_2k != nil {
+ fields = append(fields, group.FieldImagePrice2k)
+ }
+ if m.addimage_price_4k != nil {
+ fields = append(fields, group.FieldImagePrice4k)
+ }
+ if m.addfallback_group_id != nil {
+ fields = append(fields, group.FieldFallbackGroupID)
+ }
+ if m.addfallback_group_id_on_invalid_request != nil {
+ fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
+ }
+ if m.addsort_order != nil {
+ fields = append(fields, group.FieldSortOrder)
}
return fields
}
@@ -12032,10 +12542,30 @@ func (m *IdempotencyRecordMutation) AddedFields() []string {
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
-func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) {
+func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
switch name {
- case idempotencyrecord.FieldResponseStatus:
- return m.AddedResponseStatus()
+ case group.FieldRateMultiplier:
+ return m.AddedRateMultiplier()
+ case group.FieldDailyLimitUsd:
+ return m.AddedDailyLimitUsd()
+ case group.FieldWeeklyLimitUsd:
+ return m.AddedWeeklyLimitUsd()
+ case group.FieldMonthlyLimitUsd:
+ return m.AddedMonthlyLimitUsd()
+ case group.FieldDefaultValidityDays:
+ return m.AddedDefaultValidityDays()
+ case group.FieldImagePrice1k:
+ return m.AddedImagePrice1k()
+ case group.FieldImagePrice2k:
+ return m.AddedImagePrice2k()
+ case group.FieldImagePrice4k:
+ return m.AddedImagePrice4k()
+ case group.FieldFallbackGroupID:
+ return m.AddedFallbackGroupID()
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ return m.AddedFallbackGroupIDOnInvalidRequest()
+ case group.FieldSortOrder:
+ return m.AddedSortOrder()
}
return nil, false
}
@@ -12043,182 +12573,524 @@ func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) {
// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
-func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error {
+func (m *GroupMutation) AddField(name string, value ent.Value) error {
switch name {
- case idempotencyrecord.FieldResponseStatus:
- v, ok := value.(int)
+ case group.FieldRateMultiplier:
+ v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.AddResponseStatus(v)
+ m.AddRateMultiplier(v)
return nil
- }
- return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name)
-}
-
-// ClearedFields returns all nullable fields that were cleared during this
-// mutation.
-func (m *IdempotencyRecordMutation) ClearedFields() []string {
- var fields []string
- if m.FieldCleared(idempotencyrecord.FieldResponseStatus) {
- fields = append(fields, idempotencyrecord.FieldResponseStatus)
- }
- if m.FieldCleared(idempotencyrecord.FieldResponseBody) {
- fields = append(fields, idempotencyrecord.FieldResponseBody)
- }
- if m.FieldCleared(idempotencyrecord.FieldErrorReason) {
- fields = append(fields, idempotencyrecord.FieldErrorReason)
- }
- if m.FieldCleared(idempotencyrecord.FieldLockedUntil) {
- fields = append(fields, idempotencyrecord.FieldLockedUntil)
- }
- return fields
-}
-
-// FieldCleared returns a boolean indicating if a field with the given name was
-// cleared in this mutation.
-func (m *IdempotencyRecordMutation) FieldCleared(name string) bool {
- _, ok := m.clearedFields[name]
- return ok
-}
-
-// ClearField clears the value of the field with the given name. It returns an
-// error if the field is not defined in the schema.
-func (m *IdempotencyRecordMutation) ClearField(name string) error {
- switch name {
- case idempotencyrecord.FieldResponseStatus:
- m.ClearResponseStatus()
+ case group.FieldDailyLimitUsd:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddDailyLimitUsd(v)
return nil
- case idempotencyrecord.FieldResponseBody:
- m.ClearResponseBody()
+ case group.FieldWeeklyLimitUsd:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddWeeklyLimitUsd(v)
return nil
- case idempotencyrecord.FieldErrorReason:
- m.ClearErrorReason()
+ case group.FieldMonthlyLimitUsd:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddMonthlyLimitUsd(v)
return nil
- case idempotencyrecord.FieldLockedUntil:
- m.ClearLockedUntil()
+ case group.FieldDefaultValidityDays:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddDefaultValidityDays(v)
+ return nil
+ case group.FieldImagePrice1k:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddImagePrice1k(v)
+ return nil
+ case group.FieldImagePrice2k:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddImagePrice2k(v)
+ return nil
+ case group.FieldImagePrice4k:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddImagePrice4k(v)
+ return nil
+ case group.FieldFallbackGroupID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddFallbackGroupID(v)
+ return nil
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddFallbackGroupIDOnInvalidRequest(v)
+ return nil
+ case group.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSortOrder(v)
return nil
}
- return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name)
+ return fmt.Errorf("unknown Group numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *GroupMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(group.FieldDeletedAt) {
+ fields = append(fields, group.FieldDeletedAt)
+ }
+ if m.FieldCleared(group.FieldDescription) {
+ fields = append(fields, group.FieldDescription)
+ }
+ if m.FieldCleared(group.FieldDailyLimitUsd) {
+ fields = append(fields, group.FieldDailyLimitUsd)
+ }
+ if m.FieldCleared(group.FieldWeeklyLimitUsd) {
+ fields = append(fields, group.FieldWeeklyLimitUsd)
+ }
+ if m.FieldCleared(group.FieldMonthlyLimitUsd) {
+ fields = append(fields, group.FieldMonthlyLimitUsd)
+ }
+ if m.FieldCleared(group.FieldImagePrice1k) {
+ fields = append(fields, group.FieldImagePrice1k)
+ }
+ if m.FieldCleared(group.FieldImagePrice2k) {
+ fields = append(fields, group.FieldImagePrice2k)
+ }
+ if m.FieldCleared(group.FieldImagePrice4k) {
+ fields = append(fields, group.FieldImagePrice4k)
+ }
+ if m.FieldCleared(group.FieldFallbackGroupID) {
+ fields = append(fields, group.FieldFallbackGroupID)
+ }
+ if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) {
+ fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest)
+ }
+ if m.FieldCleared(group.FieldModelRouting) {
+ fields = append(fields, group.FieldModelRouting)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *GroupMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *GroupMutation) ClearField(name string) error {
+ switch name {
+ case group.FieldDeletedAt:
+ m.ClearDeletedAt()
+ return nil
+ case group.FieldDescription:
+ m.ClearDescription()
+ return nil
+ case group.FieldDailyLimitUsd:
+ m.ClearDailyLimitUsd()
+ return nil
+ case group.FieldWeeklyLimitUsd:
+ m.ClearWeeklyLimitUsd()
+ return nil
+ case group.FieldMonthlyLimitUsd:
+ m.ClearMonthlyLimitUsd()
+ return nil
+ case group.FieldImagePrice1k:
+ m.ClearImagePrice1k()
+ return nil
+ case group.FieldImagePrice2k:
+ m.ClearImagePrice2k()
+ return nil
+ case group.FieldImagePrice4k:
+ m.ClearImagePrice4k()
+ return nil
+ case group.FieldFallbackGroupID:
+ m.ClearFallbackGroupID()
+ return nil
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ m.ClearFallbackGroupIDOnInvalidRequest()
+ return nil
+ case group.FieldModelRouting:
+ m.ClearModelRouting()
+ return nil
+ }
+ return fmt.Errorf("unknown Group nullable field %s", name)
}
// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
-func (m *IdempotencyRecordMutation) ResetField(name string) error {
+func (m *GroupMutation) ResetField(name string) error {
switch name {
- case idempotencyrecord.FieldCreatedAt:
+ case group.FieldCreatedAt:
m.ResetCreatedAt()
return nil
- case idempotencyrecord.FieldUpdatedAt:
+ case group.FieldUpdatedAt:
m.ResetUpdatedAt()
return nil
- case idempotencyrecord.FieldScope:
- m.ResetScope()
+ case group.FieldDeletedAt:
+ m.ResetDeletedAt()
return nil
- case idempotencyrecord.FieldIdempotencyKeyHash:
- m.ResetIdempotencyKeyHash()
+ case group.FieldName:
+ m.ResetName()
return nil
- case idempotencyrecord.FieldRequestFingerprint:
- m.ResetRequestFingerprint()
+ case group.FieldDescription:
+ m.ResetDescription()
return nil
- case idempotencyrecord.FieldStatus:
+ case group.FieldRateMultiplier:
+ m.ResetRateMultiplier()
+ return nil
+ case group.FieldIsExclusive:
+ m.ResetIsExclusive()
+ return nil
+ case group.FieldStatus:
m.ResetStatus()
return nil
- case idempotencyrecord.FieldResponseStatus:
- m.ResetResponseStatus()
+ case group.FieldPlatform:
+ m.ResetPlatform()
return nil
- case idempotencyrecord.FieldResponseBody:
- m.ResetResponseBody()
+ case group.FieldSubscriptionType:
+ m.ResetSubscriptionType()
return nil
- case idempotencyrecord.FieldErrorReason:
- m.ResetErrorReason()
+ case group.FieldDailyLimitUsd:
+ m.ResetDailyLimitUsd()
return nil
- case idempotencyrecord.FieldLockedUntil:
- m.ResetLockedUntil()
+ case group.FieldWeeklyLimitUsd:
+ m.ResetWeeklyLimitUsd()
return nil
- case idempotencyrecord.FieldExpiresAt:
- m.ResetExpiresAt()
+ case group.FieldMonthlyLimitUsd:
+ m.ResetMonthlyLimitUsd()
+ return nil
+ case group.FieldDefaultValidityDays:
+ m.ResetDefaultValidityDays()
+ return nil
+ case group.FieldImagePrice1k:
+ m.ResetImagePrice1k()
+ return nil
+ case group.FieldImagePrice2k:
+ m.ResetImagePrice2k()
+ return nil
+ case group.FieldImagePrice4k:
+ m.ResetImagePrice4k()
+ return nil
+ case group.FieldClaudeCodeOnly:
+ m.ResetClaudeCodeOnly()
+ return nil
+ case group.FieldFallbackGroupID:
+ m.ResetFallbackGroupID()
+ return nil
+ case group.FieldFallbackGroupIDOnInvalidRequest:
+ m.ResetFallbackGroupIDOnInvalidRequest()
+ return nil
+ case group.FieldModelRouting:
+ m.ResetModelRouting()
+ return nil
+ case group.FieldModelRoutingEnabled:
+ m.ResetModelRoutingEnabled()
+ return nil
+ case group.FieldMcpXMLInject:
+ m.ResetMcpXMLInject()
+ return nil
+ case group.FieldSupportedModelScopes:
+ m.ResetSupportedModelScopes()
+ return nil
+ case group.FieldSortOrder:
+ m.ResetSortOrder()
+ return nil
+ case group.FieldAllowMessagesDispatch:
+ m.ResetAllowMessagesDispatch()
+ return nil
+ case group.FieldRequireOauthOnly:
+ m.ResetRequireOauthOnly()
+ return nil
+ case group.FieldRequirePrivacySet:
+ m.ResetRequirePrivacySet()
+ return nil
+ case group.FieldDefaultMappedModel:
+ m.ResetDefaultMappedModel()
+ return nil
+ case group.FieldMessagesDispatchModelConfig:
+ m.ResetMessagesDispatchModelConfig()
return nil
}
- return fmt.Errorf("unknown IdempotencyRecord field %s", name)
+ return fmt.Errorf("unknown Group field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
-func (m *IdempotencyRecordMutation) AddedEdges() []string {
- edges := make([]string, 0, 0)
- return edges
-}
-
-// AddedIDs returns all IDs (to other nodes) that were added for the given edge
-// name in this mutation.
-func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value {
- return nil
-}
+func (m *GroupMutation) AddedEdges() []string {
+ edges := make([]string, 0, 6)
+ if m.api_keys != nil {
+ edges = append(edges, group.EdgeAPIKeys)
+ }
+ if m.redeem_codes != nil {
+ edges = append(edges, group.EdgeRedeemCodes)
+ }
+ if m.subscriptions != nil {
+ edges = append(edges, group.EdgeSubscriptions)
+ }
+ if m.usage_logs != nil {
+ edges = append(edges, group.EdgeUsageLogs)
+ }
+ if m.accounts != nil {
+ edges = append(edges, group.EdgeAccounts)
+ }
+ if m.allowed_users != nil {
+ edges = append(edges, group.EdgeAllowedUsers)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *GroupMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case group.EdgeAPIKeys:
+ ids := make([]ent.Value, 0, len(m.api_keys))
+ for id := range m.api_keys {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeRedeemCodes:
+ ids := make([]ent.Value, 0, len(m.redeem_codes))
+ for id := range m.redeem_codes {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeSubscriptions:
+ ids := make([]ent.Value, 0, len(m.subscriptions))
+ for id := range m.subscriptions {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeUsageLogs:
+ ids := make([]ent.Value, 0, len(m.usage_logs))
+ for id := range m.usage_logs {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeAccounts:
+ ids := make([]ent.Value, 0, len(m.accounts))
+ for id := range m.accounts {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeAllowedUsers:
+ ids := make([]ent.Value, 0, len(m.allowed_users))
+ for id := range m.allowed_users {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
// RemovedEdges returns all edge names that were removed in this mutation.
-func (m *IdempotencyRecordMutation) RemovedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *GroupMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 6)
+ if m.removedapi_keys != nil {
+ edges = append(edges, group.EdgeAPIKeys)
+ }
+ if m.removedredeem_codes != nil {
+ edges = append(edges, group.EdgeRedeemCodes)
+ }
+ if m.removedsubscriptions != nil {
+ edges = append(edges, group.EdgeSubscriptions)
+ }
+ if m.removedusage_logs != nil {
+ edges = append(edges, group.EdgeUsageLogs)
+ }
+ if m.removedaccounts != nil {
+ edges = append(edges, group.EdgeAccounts)
+ }
+ if m.removedallowed_users != nil {
+ edges = append(edges, group.EdgeAllowedUsers)
+ }
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
-func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value {
+func (m *GroupMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case group.EdgeAPIKeys:
+ ids := make([]ent.Value, 0, len(m.removedapi_keys))
+ for id := range m.removedapi_keys {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeRedeemCodes:
+ ids := make([]ent.Value, 0, len(m.removedredeem_codes))
+ for id := range m.removedredeem_codes {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeSubscriptions:
+ ids := make([]ent.Value, 0, len(m.removedsubscriptions))
+ for id := range m.removedsubscriptions {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeUsageLogs:
+ ids := make([]ent.Value, 0, len(m.removedusage_logs))
+ for id := range m.removedusage_logs {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeAccounts:
+ ids := make([]ent.Value, 0, len(m.removedaccounts))
+ for id := range m.removedaccounts {
+ ids = append(ids, id)
+ }
+ return ids
+ case group.EdgeAllowedUsers:
+ ids := make([]ent.Value, 0, len(m.removedallowed_users))
+ for id := range m.removedallowed_users {
+ ids = append(ids, id)
+ }
+ return ids
+ }
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
-func (m *IdempotencyRecordMutation) ClearedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *GroupMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 6)
+ if m.clearedapi_keys {
+ edges = append(edges, group.EdgeAPIKeys)
+ }
+ if m.clearedredeem_codes {
+ edges = append(edges, group.EdgeRedeemCodes)
+ }
+ if m.clearedsubscriptions {
+ edges = append(edges, group.EdgeSubscriptions)
+ }
+ if m.clearedusage_logs {
+ edges = append(edges, group.EdgeUsageLogs)
+ }
+ if m.clearedaccounts {
+ edges = append(edges, group.EdgeAccounts)
+ }
+ if m.clearedallowed_users {
+ edges = append(edges, group.EdgeAllowedUsers)
+ }
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
-func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool {
+func (m *GroupMutation) EdgeCleared(name string) bool {
+ switch name {
+ case group.EdgeAPIKeys:
+ return m.clearedapi_keys
+ case group.EdgeRedeemCodes:
+ return m.clearedredeem_codes
+ case group.EdgeSubscriptions:
+ return m.clearedsubscriptions
+ case group.EdgeUsageLogs:
+ return m.clearedusage_logs
+ case group.EdgeAccounts:
+ return m.clearedaccounts
+ case group.EdgeAllowedUsers:
+ return m.clearedallowed_users
+ }
return false
}
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
-func (m *IdempotencyRecordMutation) ClearEdge(name string) error {
- return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name)
+func (m *GroupMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown Group unique edge %s", name)
}
// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
-func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
- return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
+func (m *GroupMutation) ResetEdge(name string) error {
+ switch name {
+ case group.EdgeAPIKeys:
+ m.ResetAPIKeys()
+ return nil
+ case group.EdgeRedeemCodes:
+ m.ResetRedeemCodes()
+ return nil
+ case group.EdgeSubscriptions:
+ m.ResetSubscriptions()
+ return nil
+ case group.EdgeUsageLogs:
+ m.ResetUsageLogs()
+ return nil
+ case group.EdgeAccounts:
+ m.ResetAccounts()
+ return nil
+ case group.EdgeAllowedUsers:
+ m.ResetAllowedUsers()
+ return nil
+ }
+ return fmt.Errorf("unknown Group edge %s", name)
}
-// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph.
-type PaymentAuditLogMutation struct {
+// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph.
+type IdempotencyRecordMutation struct {
config
- op Op
- typ string
- id *int64
- order_id *string
- action *string
- detail *string
- operator *string
- created_at *time.Time
- clearedFields map[string]struct{}
- done bool
- oldValue func(context.Context) (*PaymentAuditLog, error)
- predicates []predicate.PaymentAuditLog
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ scope *string
+ idempotency_key_hash *string
+ request_fingerprint *string
+ status *string
+ response_status *int
+ addresponse_status *int
+ response_body *string
+ error_reason *string
+ locked_until *time.Time
+ expires_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*IdempotencyRecord, error)
+ predicates []predicate.IdempotencyRecord
}
-var _ ent.Mutation = (*PaymentAuditLogMutation)(nil)
+var _ ent.Mutation = (*IdempotencyRecordMutation)(nil)
-// paymentauditlogOption allows management of the mutation configuration using functional options.
-type paymentauditlogOption func(*PaymentAuditLogMutation)
+// idempotencyrecordOption allows management of the mutation configuration using functional options.
+type idempotencyrecordOption func(*IdempotencyRecordMutation)
-// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity.
-func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation {
- m := &PaymentAuditLogMutation{
+// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity.
+func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation {
+ m := &IdempotencyRecordMutation{
config: c,
op: op,
- typ: TypePaymentAuditLog,
+ typ: TypeIdempotencyRecord,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
@@ -12227,20 +13099,20 @@ func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption)
return m
}
-// withPaymentAuditLogID sets the ID field of the mutation.
-func withPaymentAuditLogID(id int64) paymentauditlogOption {
- return func(m *PaymentAuditLogMutation) {
+// withIdempotencyRecordID sets the ID field of the mutation.
+func withIdempotencyRecordID(id int64) idempotencyrecordOption {
+ return func(m *IdempotencyRecordMutation) {
var (
err error
once sync.Once
- value *PaymentAuditLog
+ value *IdempotencyRecord
)
- m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) {
+ m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
- value, err = m.Client().PaymentAuditLog.Get(ctx, id)
+ value, err = m.Client().IdempotencyRecord.Get(ctx, id)
}
})
return value, err
@@ -12249,10 +13121,10 @@ func withPaymentAuditLogID(id int64) paymentauditlogOption {
}
}
-// withPaymentAuditLog sets the old PaymentAuditLog of the mutation.
-func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption {
- return func(m *PaymentAuditLogMutation) {
- m.oldValue = func(context.Context) (*PaymentAuditLog, error) {
+// withIdempotencyRecord sets the old IdempotencyRecord of the mutation.
+func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption {
+ return func(m *IdempotencyRecordMutation) {
+ m.oldValue = func(context.Context) (*IdempotencyRecord, error) {
return node, nil
}
m.id = &node.ID
@@ -12261,7 +13133,7 @@ func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption {
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m PaymentAuditLogMutation) Client() *Client {
+func (m IdempotencyRecordMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
@@ -12269,7 +13141,7 @@ func (m PaymentAuditLogMutation) Client() *Client {
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
-func (m PaymentAuditLogMutation) Tx() (*Tx, error) {
+func (m IdempotencyRecordMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
@@ -12280,7 +13152,7 @@ func (m PaymentAuditLogMutation) Tx() (*Tx, error) {
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
-func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) {
+func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) {
if m.id == nil {
return
}
@@ -12291,7 +13163,7 @@ func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) {
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
-func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) {
+func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
@@ -12300,3381 +13172,6098 @@ func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) {
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
- return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx)
+ return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
-// SetOrderID sets the "order_id" field.
-func (m *PaymentAuditLogMutation) SetOrderID(s string) {
- m.order_id = &s
+// SetCreatedAt sets the "created_at" field.
+func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
}
-// OrderID returns the value of the "order_id" field in the mutation.
-func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) {
- v := m.order_id
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
if v == nil {
return
}
return *v, true
}
-// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity.
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) {
+func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldOrderID is only allowed on UpdateOne operations")
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldOrderID requires an ID field in the mutation")
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldOrderID: %w", err)
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
- return oldValue.OrderID, nil
+ return oldValue.CreatedAt, nil
}
-// ResetOrderID resets all changes to the "order_id" field.
-func (m *PaymentAuditLogMutation) ResetOrderID() {
- m.order_id = nil
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdempotencyRecordMutation) ResetCreatedAt() {
+ m.created_at = nil
}
-// SetAction sets the "action" field.
-func (m *PaymentAuditLogMutation) SetAction(s string) {
- m.action = &s
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
}
-// Action returns the value of the "action" field in the mutation.
-func (m *PaymentAuditLogMutation) Action() (r string, exists bool) {
- v := m.action
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
if v == nil {
return
}
return *v, true
}
-// OldAction returns the old "action" field's value of the PaymentAuditLog entity.
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) {
+func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldAction is only allowed on UpdateOne operations")
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldAction requires an ID field in the mutation")
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldAction: %w", err)
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
}
- return oldValue.Action, nil
+ return oldValue.UpdatedAt, nil
}
-// ResetAction resets all changes to the "action" field.
-func (m *PaymentAuditLogMutation) ResetAction() {
- m.action = nil
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdempotencyRecordMutation) ResetUpdatedAt() {
+ m.updated_at = nil
}
-// SetDetail sets the "detail" field.
-func (m *PaymentAuditLogMutation) SetDetail(s string) {
- m.detail = &s
+// SetScope sets the "scope" field.
+func (m *IdempotencyRecordMutation) SetScope(s string) {
+ m.scope = &s
}
-// Detail returns the value of the "detail" field in the mutation.
-func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) {
- v := m.detail
+// Scope returns the value of the "scope" field in the mutation.
+func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) {
+ v := m.scope
if v == nil {
return
}
return *v, true
}
-// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity.
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// OldScope returns the old "scope" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) {
+func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDetail is only allowed on UpdateOne operations")
+ return v, errors.New("OldScope is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDetail requires an ID field in the mutation")
+ return v, errors.New("OldScope requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldDetail: %w", err)
+ return v, fmt.Errorf("querying old value for OldScope: %w", err)
}
- return oldValue.Detail, nil
+ return oldValue.Scope, nil
}
-// ResetDetail resets all changes to the "detail" field.
-func (m *PaymentAuditLogMutation) ResetDetail() {
- m.detail = nil
+// ResetScope resets all changes to the "scope" field.
+func (m *IdempotencyRecordMutation) ResetScope() {
+ m.scope = nil
}
-// SetOperator sets the "operator" field.
-func (m *PaymentAuditLogMutation) SetOperator(s string) {
- m.operator = &s
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) {
+ m.idempotency_key_hash = &s
}
-// Operator returns the value of the "operator" field in the mutation.
-func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) {
- v := m.operator
+// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation.
+func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) {
+ v := m.idempotency_key_hash
if v == nil {
return
}
return *v, true
}
-// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity.
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) {
+func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldOperator is only allowed on UpdateOne operations")
+ return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldOperator requires an ID field in the mutation")
+ return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldOperator: %w", err)
+ return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err)
}
- return oldValue.Operator, nil
+ return oldValue.IdempotencyKeyHash, nil
}
-// ResetOperator resets all changes to the "operator" field.
-func (m *PaymentAuditLogMutation) ResetOperator() {
- m.operator = nil
+// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field.
+func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() {
+ m.idempotency_key_hash = nil
}
-// SetCreatedAt sets the "created_at" field.
-func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) {
- m.created_at = &t
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) {
+ m.request_fingerprint = &s
}
-// CreatedAt returns the value of the "created_at" field in the mutation.
-func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) {
- v := m.created_at
+// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation.
+func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) {
+ v := m.request_fingerprint
if v == nil {
return
}
return *v, true
}
-// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity.
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ return v, errors.New("OldRequestFingerprint requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err)
}
- return oldValue.CreatedAt, nil
+ return oldValue.RequestFingerprint, nil
}
-// ResetCreatedAt resets all changes to the "created_at" field.
-func (m *PaymentAuditLogMutation) ResetCreatedAt() {
- m.created_at = nil
+// ResetRequestFingerprint resets all changes to the "request_fingerprint" field.
+func (m *IdempotencyRecordMutation) ResetRequestFingerprint() {
+ m.request_fingerprint = nil
}
-// Where appends a list predicates to the PaymentAuditLogMutation builder.
-func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) {
- m.predicates = append(m.predicates, ps...)
+// SetStatus sets the "status" field.
+func (m *IdempotencyRecordMutation) SetStatus(s string) {
+ m.status = &s
}
-// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method,
-// users can use type-assertion to append predicates that do not depend on any generated package.
-func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) {
- p := make([]predicate.PaymentAuditLog, len(ps))
- for i := range ps {
- p[i] = ps[i]
+// Status returns the value of the "status" field in the mutation.
+func (m *IdempotencyRecordMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
}
- m.Where(p...)
+ return *v, true
}
-// Op returns the operation name.
-func (m *PaymentAuditLogMutation) Op() Op {
- return m.op
+// OldStatus returns the old "status" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
}
-// SetOp allows setting the mutation operation.
-func (m *PaymentAuditLogMutation) SetOp(op Op) {
- m.op = op
+// ResetStatus resets all changes to the "status" field.
+func (m *IdempotencyRecordMutation) ResetStatus() {
+ m.status = nil
}
-// Type returns the node type of this mutation (PaymentAuditLog).
-func (m *PaymentAuditLogMutation) Type() string {
- return m.typ
+// SetResponseStatus sets the "response_status" field.
+func (m *IdempotencyRecordMutation) SetResponseStatus(i int) {
+ m.response_status = &i
+ m.addresponse_status = nil
}
-// Fields returns all fields that were changed during this mutation. Note that in
-// order to get all numeric fields that were incremented/decremented, call
-// AddedFields().
-func (m *PaymentAuditLogMutation) Fields() []string {
- fields := make([]string, 0, 5)
- if m.order_id != nil {
- fields = append(fields, paymentauditlog.FieldOrderID)
- }
- if m.action != nil {
- fields = append(fields, paymentauditlog.FieldAction)
- }
- if m.detail != nil {
- fields = append(fields, paymentauditlog.FieldDetail)
- }
- if m.operator != nil {
- fields = append(fields, paymentauditlog.FieldOperator)
- }
- if m.created_at != nil {
- fields = append(fields, paymentauditlog.FieldCreatedAt)
+// ResponseStatus returns the value of the "response_status" field in the mutation.
+func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) {
+ v := m.response_status
+ if v == nil {
+ return
}
- return fields
+ return *v, true
}
-// Field returns the value of a field with the given name. The second boolean
-// return value indicates that this field was not set, or was not defined in the
-// schema.
-func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) {
- switch name {
- case paymentauditlog.FieldOrderID:
- return m.OrderID()
- case paymentauditlog.FieldAction:
- return m.Action()
- case paymentauditlog.FieldDetail:
- return m.Detail()
- case paymentauditlog.FieldOperator:
- return m.Operator()
- case paymentauditlog.FieldCreatedAt:
- return m.CreatedAt()
+// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations")
}
- return nil, false
-}
-
-// OldField returns the old value of the field from the database. An error is
-// returned if the mutation operation is not UpdateOne, or the query to the
-// database failed.
-func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
- switch name {
- case paymentauditlog.FieldOrderID:
- return m.OldOrderID(ctx)
- case paymentauditlog.FieldAction:
- return m.OldAction(ctx)
- case paymentauditlog.FieldDetail:
- return m.OldDetail(ctx)
- case paymentauditlog.FieldOperator:
- return m.OldOperator(ctx)
- case paymentauditlog.FieldCreatedAt:
- return m.OldCreatedAt(ctx)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResponseStatus requires an ID field in the mutation")
}
- return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err)
+ }
+ return oldValue.ResponseStatus, nil
}
-// SetField sets the value of a field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error {
- switch name {
- case paymentauditlog.FieldOrderID:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetOrderID(v)
- return nil
- case paymentauditlog.FieldAction:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetAction(v)
- return nil
- case paymentauditlog.FieldDetail:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDetail(v)
- return nil
- case paymentauditlog.FieldOperator:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetOperator(v)
- return nil
- case paymentauditlog.FieldCreatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetCreatedAt(v)
- return nil
+// AddResponseStatus adds i to the "response_status" field.
+func (m *IdempotencyRecordMutation) AddResponseStatus(i int) {
+ if m.addresponse_status != nil {
+ *m.addresponse_status += i
+ } else {
+ m.addresponse_status = &i
}
- return fmt.Errorf("unknown PaymentAuditLog field %s", name)
}
-// AddedFields returns all numeric fields that were incremented/decremented during
-// this mutation.
-func (m *PaymentAuditLogMutation) AddedFields() []string {
- return nil
+// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation.
+func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) {
+ v := m.addresponse_status
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// AddedField returns the numeric value that was incremented/decremented on a field
-// with the given name. The second boolean return value indicates that this field
-// was not set, or was not defined in the schema.
-func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) {
- return nil, false
+// ClearResponseStatus clears the value of the "response_status" field.
+func (m *IdempotencyRecordMutation) ClearResponseStatus() {
+ m.response_status = nil
+ m.addresponse_status = nil
+ m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{}
}
-// AddField adds the value to the field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error {
- switch name {
- }
- return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name)
+// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus]
+ return ok
}
-// ClearedFields returns all nullable fields that were cleared during this
-// mutation.
-func (m *PaymentAuditLogMutation) ClearedFields() []string {
- return nil
+// ResetResponseStatus resets all changes to the "response_status" field.
+func (m *IdempotencyRecordMutation) ResetResponseStatus() {
+ m.response_status = nil
+ m.addresponse_status = nil
+ delete(m.clearedFields, idempotencyrecord.FieldResponseStatus)
}
-// FieldCleared returns a boolean indicating if a field with the given name was
-// cleared in this mutation.
-func (m *PaymentAuditLogMutation) FieldCleared(name string) bool {
- _, ok := m.clearedFields[name]
- return ok
+// SetResponseBody sets the "response_body" field.
+func (m *IdempotencyRecordMutation) SetResponseBody(s string) {
+ m.response_body = &s
}
-// ClearField clears the value of the field with the given name. It returns an
-// error if the field is not defined in the schema.
-func (m *PaymentAuditLogMutation) ClearField(name string) error {
- return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name)
+// ResponseBody returns the value of the "response_body" field in the mutation.
+func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) {
+ v := m.response_body
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// ResetField resets all changes in the mutation for the field with the given name.
-// It returns an error if the field is not defined in the schema.
-func (m *PaymentAuditLogMutation) ResetField(name string) error {
- switch name {
- case paymentauditlog.FieldOrderID:
- m.ResetOrderID()
- return nil
- case paymentauditlog.FieldAction:
- m.ResetAction()
- return nil
- case paymentauditlog.FieldDetail:
- m.ResetDetail()
- return nil
- case paymentauditlog.FieldOperator:
- m.ResetOperator()
- return nil
- case paymentauditlog.FieldCreatedAt:
- m.ResetCreatedAt()
- return nil
+// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResponseBody is only allowed on UpdateOne operations")
}
- return fmt.Errorf("unknown PaymentAuditLog field %s", name)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResponseBody requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResponseBody: %w", err)
+ }
+ return oldValue.ResponseBody, nil
}
-// AddedEdges returns all edge names that were set/added in this mutation.
-func (m *PaymentAuditLogMutation) AddedEdges() []string {
- edges := make([]string, 0, 0)
- return edges
+// ClearResponseBody clears the value of the "response_body" field.
+func (m *IdempotencyRecordMutation) ClearResponseBody() {
+ m.response_body = nil
+ m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{}
}
-// AddedIDs returns all IDs (to other nodes) that were added for the given edge
-// name in this mutation.
-func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value {
- return nil
+// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody]
+ return ok
}
-// RemovedEdges returns all edge names that were removed in this mutation.
-func (m *PaymentAuditLogMutation) RemovedEdges() []string {
- edges := make([]string, 0, 0)
- return edges
+// ResetResponseBody resets all changes to the "response_body" field.
+func (m *IdempotencyRecordMutation) ResetResponseBody() {
+ m.response_body = nil
+ delete(m.clearedFields, idempotencyrecord.FieldResponseBody)
}
-// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
-// the given name in this mutation.
-func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value {
- return nil
+// SetErrorReason sets the "error_reason" field.
+func (m *IdempotencyRecordMutation) SetErrorReason(s string) {
+ m.error_reason = &s
}
-// ClearedEdges returns all edge names that were cleared in this mutation.
-func (m *PaymentAuditLogMutation) ClearedEdges() []string {
- edges := make([]string, 0, 0)
- return edges
+// ErrorReason returns the value of the "error_reason" field in the mutation.
+func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) {
+ v := m.error_reason
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// EdgeCleared returns a boolean which indicates if the edge with the given name
-// was cleared in this mutation.
-func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool {
- return false
+// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldErrorReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldErrorReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldErrorReason: %w", err)
+ }
+ return oldValue.ErrorReason, nil
}
-// ClearEdge clears the value of the edge with the given name. It returns an error
-// if that edge is not defined in the schema.
-func (m *PaymentAuditLogMutation) ClearEdge(name string) error {
- return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name)
+// ClearErrorReason clears the value of the "error_reason" field.
+func (m *IdempotencyRecordMutation) ClearErrorReason() {
+ m.error_reason = nil
+ m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{}
}
-// ResetEdge resets all changes to the edge with the given name in this mutation.
-// It returns an error if the edge is not defined in the schema.
-func (m *PaymentAuditLogMutation) ResetEdge(name string) error {
- return fmt.Errorf("unknown PaymentAuditLog edge %s", name)
-}
-
-// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph.
-type PaymentOrderMutation struct {
- config
- op Op
- typ string
- id *int64
- user_email *string
- user_name *string
- user_notes *string
- amount *float64
- addamount *float64
- pay_amount *float64
- addpay_amount *float64
- fee_rate *float64
- addfee_rate *float64
- recharge_code *string
- out_trade_no *string
- payment_type *string
- payment_trade_no *string
- pay_url *string
- qr_code *string
- qr_code_img *string
- order_type *string
- plan_id *int64
- addplan_id *int64
- subscription_group_id *int64
- addsubscription_group_id *int64
- subscription_days *int
- addsubscription_days *int
- provider_instance_id *string
- status *string
- refund_amount *float64
- addrefund_amount *float64
- refund_reason *string
- refund_at *time.Time
- force_refund *bool
- refund_requested_at *time.Time
- refund_request_reason *string
- refund_requested_by *string
- expires_at *time.Time
- paid_at *time.Time
- completed_at *time.Time
- failed_at *time.Time
- failed_reason *string
- client_ip *string
- src_host *string
- src_url *string
- created_at *time.Time
- updated_at *time.Time
- clearedFields map[string]struct{}
- user *int64
- cleareduser bool
- done bool
- oldValue func(context.Context) (*PaymentOrder, error)
- predicates []predicate.PaymentOrder
-}
-
-var _ ent.Mutation = (*PaymentOrderMutation)(nil)
-
-// paymentorderOption allows management of the mutation configuration using functional options.
-type paymentorderOption func(*PaymentOrderMutation)
-
-// newPaymentOrderMutation creates new mutation for the PaymentOrder entity.
-func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation {
- m := &PaymentOrderMutation{
- config: c,
- op: op,
- typ: TypePaymentOrder,
- clearedFields: make(map[string]struct{}),
- }
- for _, opt := range opts {
- opt(m)
- }
- return m
-}
-
-// withPaymentOrderID sets the ID field of the mutation.
-func withPaymentOrderID(id int64) paymentorderOption {
- return func(m *PaymentOrderMutation) {
- var (
- err error
- once sync.Once
- value *PaymentOrder
- )
- m.oldValue = func(ctx context.Context) (*PaymentOrder, error) {
- once.Do(func() {
- if m.done {
- err = errors.New("querying old values post mutation is not allowed")
- } else {
- value, err = m.Client().PaymentOrder.Get(ctx, id)
- }
- })
- return value, err
- }
- m.id = &id
- }
-}
-
-// withPaymentOrder sets the old PaymentOrder of the mutation.
-func withPaymentOrder(node *PaymentOrder) paymentorderOption {
- return func(m *PaymentOrderMutation) {
- m.oldValue = func(context.Context) (*PaymentOrder, error) {
- return node, nil
- }
- m.id = &node.ID
- }
-}
-
-// Client returns a new `ent.Client` from the mutation. If the mutation was
-// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m PaymentOrderMutation) Client() *Client {
- client := &Client{config: m.config}
- client.init()
- return client
-}
-
-// Tx returns an `ent.Tx` for mutations that were executed in transactions;
-// it returns an error otherwise.
-func (m PaymentOrderMutation) Tx() (*Tx, error) {
- if _, ok := m.driver.(*txDriver); !ok {
- return nil, errors.New("ent: mutation is not running in a transaction")
- }
- tx := &Tx{config: m.config}
- tx.init()
- return tx, nil
-}
-
-// ID returns the ID value in the mutation. Note that the ID is only available
-// if it was provided to the builder or after it was returned from the database.
-func (m *PaymentOrderMutation) ID() (id int64, exists bool) {
- if m.id == nil {
- return
- }
- return *m.id, true
+// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason]
+ return ok
}
-// IDs queries the database and returns the entity ids that match the mutation's predicate.
-// That means, if the mutation is applied within a transaction with an isolation level such
-// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
-// or updated by the mutation.
-func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) {
- switch {
- case m.op.Is(OpUpdateOne | OpDeleteOne):
- id, exists := m.ID()
- if exists {
- return []int64{id}, nil
- }
- fallthrough
- case m.op.Is(OpUpdate | OpDelete):
- return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx)
- default:
- return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
- }
+// ResetErrorReason resets all changes to the "error_reason" field.
+func (m *IdempotencyRecordMutation) ResetErrorReason() {
+ m.error_reason = nil
+ delete(m.clearedFields, idempotencyrecord.FieldErrorReason)
}
-// SetUserID sets the "user_id" field.
-func (m *PaymentOrderMutation) SetUserID(i int64) {
- m.user = &i
+// SetLockedUntil sets the "locked_until" field.
+func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) {
+ m.locked_until = &t
}
-// UserID returns the value of the "user_id" field in the mutation.
-func (m *PaymentOrderMutation) UserID() (r int64, exists bool) {
- v := m.user
+// LockedUntil returns the value of the "locked_until" field in the mutation.
+func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) {
+ v := m.locked_until
if v == nil {
return
}
return *v, true
}
-// OldUserID returns the old "user_id" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) {
+func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUserID requires an ID field in the mutation")
+ return v, errors.New("OldLockedUntil requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err)
}
- return oldValue.UserID, nil
+ return oldValue.LockedUntil, nil
}
-// ResetUserID resets all changes to the "user_id" field.
-func (m *PaymentOrderMutation) ResetUserID() {
- m.user = nil
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (m *IdempotencyRecordMutation) ClearLockedUntil() {
+ m.locked_until = nil
+ m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{}
}
-// SetUserEmail sets the "user_email" field.
-func (m *PaymentOrderMutation) SetUserEmail(s string) {
- m.user_email = &s
+// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) LockedUntilCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil]
+ return ok
}
-// UserEmail returns the value of the "user_email" field in the mutation.
-func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) {
- v := m.user_email
+// ResetLockedUntil resets all changes to the "locked_until" field.
+func (m *IdempotencyRecordMutation) ResetLockedUntil() {
+ m.locked_until = nil
+ delete(m.clearedFields, idempotencyrecord.FieldLockedUntil)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
if v == nil {
return
}
return *v, true
}
-// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) {
+func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUserEmail is only allowed on UpdateOne operations")
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUserEmail requires an ID field in the mutation")
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldUserEmail: %w", err)
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
}
- return oldValue.UserEmail, nil
+ return oldValue.ExpiresAt, nil
}
-// ResetUserEmail resets all changes to the "user_email" field.
-func (m *PaymentOrderMutation) ResetUserEmail() {
- m.user_email = nil
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *IdempotencyRecordMutation) ResetExpiresAt() {
+ m.expires_at = nil
}
-// SetUserName sets the "user_name" field.
-func (m *PaymentOrderMutation) SetUserName(s string) {
- m.user_name = &s
+// Where appends a list predicates to the IdempotencyRecordMutation builder.
+func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) {
+ m.predicates = append(m.predicates, ps...)
}
-// UserName returns the value of the "user_name" field in the mutation.
-func (m *PaymentOrderMutation) UserName() (r string, exists bool) {
- v := m.user_name
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldUserName returns the old "user_name" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUserName is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUserName requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldUserName: %w", err)
+// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdempotencyRecord, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
}
- return oldValue.UserName, nil
+ m.Where(p...)
}
-// ResetUserName resets all changes to the "user_name" field.
-func (m *PaymentOrderMutation) ResetUserName() {
- m.user_name = nil
+// Op returns the operation name.
+func (m *IdempotencyRecordMutation) Op() Op {
+ return m.op
}
-// SetUserNotes sets the "user_notes" field.
-func (m *PaymentOrderMutation) SetUserNotes(s string) {
- m.user_notes = &s
+// SetOp allows setting the mutation operation.
+func (m *IdempotencyRecordMutation) SetOp(op Op) {
+ m.op = op
}
-// UserNotes returns the value of the "user_notes" field in the mutation.
-func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) {
- v := m.user_notes
- if v == nil {
- return
- }
- return *v, true
+// Type returns the node type of this mutation (IdempotencyRecord).
+func (m *IdempotencyRecordMutation) Type() string {
+ return m.typ
}
-// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUserNotes is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUserNotes requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldUserNotes: %w", err)
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdempotencyRecordMutation) Fields() []string {
+ fields := make([]string, 0, 11)
+ if m.created_at != nil {
+ fields = append(fields, idempotencyrecord.FieldCreatedAt)
}
- return oldValue.UserNotes, nil
-}
-
-// ClearUserNotes clears the value of the "user_notes" field.
-func (m *PaymentOrderMutation) ClearUserNotes() {
- m.user_notes = nil
- m.clearedFields[paymentorder.FieldUserNotes] = struct{}{}
-}
-
-// UserNotesCleared returns if the "user_notes" field was cleared in this mutation.
-func (m *PaymentOrderMutation) UserNotesCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldUserNotes]
- return ok
-}
-
-// ResetUserNotes resets all changes to the "user_notes" field.
-func (m *PaymentOrderMutation) ResetUserNotes() {
- m.user_notes = nil
- delete(m.clearedFields, paymentorder.FieldUserNotes)
-}
-
-// SetAmount sets the "amount" field.
-func (m *PaymentOrderMutation) SetAmount(f float64) {
- m.amount = &f
- m.addamount = nil
-}
-
-// Amount returns the value of the "amount" field in the mutation.
-func (m *PaymentOrderMutation) Amount() (r float64, exists bool) {
- v := m.amount
- if v == nil {
- return
+ if m.updated_at != nil {
+ fields = append(fields, idempotencyrecord.FieldUpdatedAt)
}
- return *v, true
-}
-
-// OldAmount returns the old "amount" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldAmount is only allowed on UpdateOne operations")
+ if m.scope != nil {
+ fields = append(fields, idempotencyrecord.FieldScope)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldAmount requires an ID field in the mutation")
+ if m.idempotency_key_hash != nil {
+ fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldAmount: %w", err)
+ if m.request_fingerprint != nil {
+ fields = append(fields, idempotencyrecord.FieldRequestFingerprint)
}
- return oldValue.Amount, nil
-}
-
-// AddAmount adds f to the "amount" field.
-func (m *PaymentOrderMutation) AddAmount(f float64) {
- if m.addamount != nil {
- *m.addamount += f
- } else {
- m.addamount = &f
+ if m.status != nil {
+ fields = append(fields, idempotencyrecord.FieldStatus)
}
-}
-
-// AddedAmount returns the value that was added to the "amount" field in this mutation.
-func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) {
- v := m.addamount
- if v == nil {
- return
+ if m.response_status != nil {
+ fields = append(fields, idempotencyrecord.FieldResponseStatus)
}
- return *v, true
-}
-
-// ResetAmount resets all changes to the "amount" field.
-func (m *PaymentOrderMutation) ResetAmount() {
- m.amount = nil
- m.addamount = nil
-}
-
-// SetPayAmount sets the "pay_amount" field.
-func (m *PaymentOrderMutation) SetPayAmount(f float64) {
- m.pay_amount = &f
- m.addpay_amount = nil
-}
-
-// PayAmount returns the value of the "pay_amount" field in the mutation.
-func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) {
- v := m.pay_amount
- if v == nil {
- return
+ if m.response_body != nil {
+ fields = append(fields, idempotencyrecord.FieldResponseBody)
}
- return *v, true
-}
-
-// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPayAmount is only allowed on UpdateOne operations")
+ if m.error_reason != nil {
+ fields = append(fields, idempotencyrecord.FieldErrorReason)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPayAmount requires an ID field in the mutation")
+ if m.locked_until != nil {
+ fields = append(fields, idempotencyrecord.FieldLockedUntil)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldPayAmount: %w", err)
+ if m.expires_at != nil {
+ fields = append(fields, idempotencyrecord.FieldExpiresAt)
}
- return oldValue.PayAmount, nil
+ return fields
}
-// AddPayAmount adds f to the "pay_amount" field.
-func (m *PaymentOrderMutation) AddPayAmount(f float64) {
- if m.addpay_amount != nil {
- *m.addpay_amount += f
- } else {
- m.addpay_amount = &f
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ return m.CreatedAt()
+ case idempotencyrecord.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case idempotencyrecord.FieldScope:
+ return m.Scope()
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ return m.IdempotencyKeyHash()
+ case idempotencyrecord.FieldRequestFingerprint:
+ return m.RequestFingerprint()
+ case idempotencyrecord.FieldStatus:
+ return m.Status()
+ case idempotencyrecord.FieldResponseStatus:
+ return m.ResponseStatus()
+ case idempotencyrecord.FieldResponseBody:
+ return m.ResponseBody()
+ case idempotencyrecord.FieldErrorReason:
+ return m.ErrorReason()
+ case idempotencyrecord.FieldLockedUntil:
+ return m.LockedUntil()
+ case idempotencyrecord.FieldExpiresAt:
+ return m.ExpiresAt()
}
+ return nil, false
}
-// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation.
-func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) {
- v := m.addpay_amount
- if v == nil {
- return
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case idempotencyrecord.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case idempotencyrecord.FieldScope:
+ return m.OldScope(ctx)
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ return m.OldIdempotencyKeyHash(ctx)
+ case idempotencyrecord.FieldRequestFingerprint:
+ return m.OldRequestFingerprint(ctx)
+ case idempotencyrecord.FieldStatus:
+ return m.OldStatus(ctx)
+ case idempotencyrecord.FieldResponseStatus:
+ return m.OldResponseStatus(ctx)
+ case idempotencyrecord.FieldResponseBody:
+ return m.OldResponseBody(ctx)
+ case idempotencyrecord.FieldErrorReason:
+ return m.OldErrorReason(ctx)
+ case idempotencyrecord.FieldLockedUntil:
+ return m.OldLockedUntil(ctx)
+ case idempotencyrecord.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
}
- return *v, true
-}
-
-// ResetPayAmount resets all changes to the "pay_amount" field.
-func (m *PaymentOrderMutation) ResetPayAmount() {
- m.pay_amount = nil
- m.addpay_amount = nil
+ return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name)
}
-// SetFeeRate sets the "fee_rate" field.
-func (m *PaymentOrderMutation) SetFeeRate(f float64) {
- m.fee_rate = &f
- m.addfee_rate = nil
-}
-
-// FeeRate returns the value of the "fee_rate" field in the mutation.
-func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) {
- v := m.fee_rate
- if v == nil {
- return
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case idempotencyrecord.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case idempotencyrecord.FieldScope:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetScope(v)
+ return nil
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdempotencyKeyHash(v)
+ return nil
+ case idempotencyrecord.FieldRequestFingerprint:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRequestFingerprint(v)
+ return nil
+ case idempotencyrecord.FieldStatus:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStatus(v)
+ return nil
+ case idempotencyrecord.FieldResponseStatus:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResponseStatus(v)
+ return nil
+ case idempotencyrecord.FieldResponseBody:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResponseBody(v)
+ return nil
+ case idempotencyrecord.FieldErrorReason:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetErrorReason(v)
+ return nil
+ case idempotencyrecord.FieldLockedUntil:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLockedUntil(v)
+ return nil
+ case idempotencyrecord.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown IdempotencyRecord field %s", name)
}
-// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldFeeRate is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldFeeRate requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldFeeRate: %w", err)
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdempotencyRecordMutation) AddedFields() []string {
+ var fields []string
+ if m.addresponse_status != nil {
+ fields = append(fields, idempotencyrecord.FieldResponseStatus)
}
- return oldValue.FeeRate, nil
+ return fields
}
-// AddFeeRate adds f to the "fee_rate" field.
-func (m *PaymentOrderMutation) AddFeeRate(f float64) {
- if m.addfee_rate != nil {
- *m.addfee_rate += f
- } else {
- m.addfee_rate = &f
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case idempotencyrecord.FieldResponseStatus:
+ return m.AddedResponseStatus()
}
+ return nil, false
}
-// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation.
-func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) {
- v := m.addfee_rate
- if v == nil {
- return
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case idempotencyrecord.FieldResponseStatus:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddResponseStatus(v)
+ return nil
}
- return *v, true
-}
-
-// ResetFeeRate resets all changes to the "fee_rate" field.
-func (m *PaymentOrderMutation) ResetFeeRate() {
- m.fee_rate = nil
- m.addfee_rate = nil
-}
-
-// SetRechargeCode sets the "recharge_code" field.
-func (m *PaymentOrderMutation) SetRechargeCode(s string) {
- m.recharge_code = &s
+ return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name)
}
-// RechargeCode returns the value of the "recharge_code" field in the mutation.
-func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) {
- v := m.recharge_code
- if v == nil {
- return
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdempotencyRecordMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(idempotencyrecord.FieldResponseStatus) {
+ fields = append(fields, idempotencyrecord.FieldResponseStatus)
}
- return *v, true
-}
-
-// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations")
+ if m.FieldCleared(idempotencyrecord.FieldResponseBody) {
+ fields = append(fields, idempotencyrecord.FieldResponseBody)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRechargeCode requires an ID field in the mutation")
+ if m.FieldCleared(idempotencyrecord.FieldErrorReason) {
+ fields = append(fields, idempotencyrecord.FieldErrorReason)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err)
+ if m.FieldCleared(idempotencyrecord.FieldLockedUntil) {
+ fields = append(fields, idempotencyrecord.FieldLockedUntil)
}
- return oldValue.RechargeCode, nil
-}
-
-// ResetRechargeCode resets all changes to the "recharge_code" field.
-func (m *PaymentOrderMutation) ResetRechargeCode() {
- m.recharge_code = nil
-}
-
-// SetOutTradeNo sets the "out_trade_no" field.
-func (m *PaymentOrderMutation) SetOutTradeNo(s string) {
- m.out_trade_no = &s
+ return fields
}
-// OutTradeNo returns the value of the "out_trade_no" field in the mutation.
-func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) {
- v := m.out_trade_no
- if v == nil {
- return
- }
- return *v, true
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdempotencyRecordMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
}
-// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldOutTradeNo requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err)
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdempotencyRecordMutation) ClearField(name string) error {
+ switch name {
+ case idempotencyrecord.FieldResponseStatus:
+ m.ClearResponseStatus()
+ return nil
+ case idempotencyrecord.FieldResponseBody:
+ m.ClearResponseBody()
+ return nil
+ case idempotencyrecord.FieldErrorReason:
+ m.ClearErrorReason()
+ return nil
+ case idempotencyrecord.FieldLockedUntil:
+ m.ClearLockedUntil()
+ return nil
}
- return oldValue.OutTradeNo, nil
-}
-
-// ResetOutTradeNo resets all changes to the "out_trade_no" field.
-func (m *PaymentOrderMutation) ResetOutTradeNo() {
- m.out_trade_no = nil
-}
-
-// SetPaymentType sets the "payment_type" field.
-func (m *PaymentOrderMutation) SetPaymentType(s string) {
- m.payment_type = &s
+ return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name)
}
-// PaymentType returns the value of the "payment_type" field in the mutation.
-func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) {
- v := m.payment_type
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdempotencyRecordMutation) ResetField(name string) error {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case idempotencyrecord.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case idempotencyrecord.FieldScope:
+ m.ResetScope()
+ return nil
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ m.ResetIdempotencyKeyHash()
+ return nil
+ case idempotencyrecord.FieldRequestFingerprint:
+ m.ResetRequestFingerprint()
+ return nil
+ case idempotencyrecord.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case idempotencyrecord.FieldResponseStatus:
+ m.ResetResponseStatus()
+ return nil
+ case idempotencyrecord.FieldResponseBody:
+ m.ResetResponseBody()
+ return nil
+ case idempotencyrecord.FieldErrorReason:
+ m.ResetErrorReason()
+ return nil
+ case idempotencyrecord.FieldLockedUntil:
+ m.ResetLockedUntil()
+ return nil
+ case idempotencyrecord.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ }
+ return fmt.Errorf("unknown IdempotencyRecord field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdempotencyRecordMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value {
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdempotencyRecordMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdempotencyRecordMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool {
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdempotencyRecordMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
+}
+
+// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph.
+type IdentityAdoptionDecisionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ adopt_display_name *bool
+ adopt_avatar *bool
+ decided_at *time.Time
+ clearedFields map[string]struct{}
+ pending_auth_session *int64
+ clearedpending_auth_session bool
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*IdentityAdoptionDecision, error)
+ predicates []predicate.IdentityAdoptionDecision
+}
+
+var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil)
+
+// identityadoptiondecisionOption allows management of the mutation configuration using functional options.
+type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation)
+
+// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity.
+func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation {
+ m := &IdentityAdoptionDecisionMutation{
+ config: c,
+ op: op,
+ typ: TypeIdentityAdoptionDecision,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withIdentityAdoptionDecisionID sets the ID field of the mutation.
+func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *IdentityAdoptionDecision
+ )
+ m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation.
+func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m IdentityAdoptionDecisionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
if v == nil {
return
}
return *v, true
}
-// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPaymentType is only allowed on UpdateOne operations")
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPaymentType requires an ID field in the mutation")
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPaymentType: %w", err)
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
- return oldValue.PaymentType, nil
+ return oldValue.CreatedAt, nil
}
-// ResetPaymentType resets all changes to the "payment_type" field.
-func (m *PaymentOrderMutation) ResetPaymentType() {
- m.payment_type = nil
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() {
+ m.created_at = nil
}
-// SetPaymentTradeNo sets the "payment_trade_no" field.
-func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) {
- m.payment_trade_no = &s
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
}
-// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation.
-func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) {
- v := m.payment_trade_no
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
if v == nil {
return
}
return *v, true
}
-// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations")
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation")
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err)
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
}
- return oldValue.PaymentTradeNo, nil
+ return oldValue.UpdatedAt, nil
}
-// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field.
-func (m *PaymentOrderMutation) ResetPaymentTradeNo() {
- m.payment_trade_no = nil
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
}
-// SetPayURL sets the "pay_url" field.
-func (m *PaymentOrderMutation) SetPayURL(s string) {
- m.pay_url = &s
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) {
+ m.pending_auth_session = &i
}
-// PayURL returns the value of the "pay_url" field in the mutation.
-func (m *PaymentOrderMutation) PayURL() (r string, exists bool) {
- v := m.pay_url
+// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) {
+ v := m.pending_auth_session
if v == nil {
return
}
return *v, true
}
-// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPayURL is only allowed on UpdateOne operations")
+ return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPayURL requires an ID field in the mutation")
+ return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPayURL: %w", err)
+ return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err)
}
- return oldValue.PayURL, nil
+ return oldValue.PendingAuthSessionID, nil
}
-// ClearPayURL clears the value of the "pay_url" field.
-func (m *PaymentOrderMutation) ClearPayURL() {
- m.pay_url = nil
- m.clearedFields[paymentorder.FieldPayURL] = struct{}{}
+// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() {
+ m.pending_auth_session = nil
}
-// PayURLCleared returns if the "pay_url" field was cleared in this mutation.
-func (m *PaymentOrderMutation) PayURLCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldPayURL]
- return ok
+// SetIdentityID sets the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) {
+ m.identity = &i
}
-// ResetPayURL resets all changes to the "pay_url" field.
-func (m *PaymentOrderMutation) ResetPayURL() {
- m.pay_url = nil
- delete(m.clearedFields, paymentorder.FieldPayURL)
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// SetQrCode sets the "qr_code" field.
-func (m *PaymentOrderMutation) SetQrCode(s string) {
- m.qr_code = &s
-}
-
-// QrCode returns the value of the "qr_code" field in the mutation.
-func (m *PaymentOrderMutation) QrCode() (r string, exists bool) {
- v := m.qr_code
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldQrCode is only allowed on UpdateOne operations")
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldQrCode requires an ID field in the mutation")
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldQrCode: %w", err)
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
}
- return oldValue.QrCode, nil
+ return oldValue.IdentityID, nil
}
-// ClearQrCode clears the value of the "qr_code" field.
-func (m *PaymentOrderMutation) ClearQrCode() {
- m.qr_code = nil
- m.clearedFields[paymentorder.FieldQrCode] = struct{}{}
+// ClearIdentityID clears the value of the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() {
+ m.identity = nil
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
}
-// QrCodeCleared returns if the "qr_code" field was cleared in this mutation.
-func (m *PaymentOrderMutation) QrCodeCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldQrCode]
+// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool {
+ _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID]
return ok
}
-// ResetQrCode resets all changes to the "qr_code" field.
-func (m *PaymentOrderMutation) ResetQrCode() {
- m.qr_code = nil
- delete(m.clearedFields, paymentorder.FieldQrCode)
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() {
+ m.identity = nil
+ delete(m.clearedFields, identityadoptiondecision.FieldIdentityID)
}
-// SetQrCodeImg sets the "qr_code_img" field.
-func (m *PaymentOrderMutation) SetQrCodeImg(s string) {
- m.qr_code_img = &s
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) {
+ m.adopt_display_name = &b
}
-// QrCodeImg returns the value of the "qr_code_img" field in the mutation.
-func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) {
- v := m.qr_code_img
+// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) {
+ v := m.adopt_display_name
if v == nil {
return
}
return *v, true
}
-// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations")
+ return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldQrCodeImg requires an ID field in the mutation")
+ return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err)
+ return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err)
}
- return oldValue.QrCodeImg, nil
-}
-
-// ClearQrCodeImg clears the value of the "qr_code_img" field.
-func (m *PaymentOrderMutation) ClearQrCodeImg() {
- m.qr_code_img = nil
- m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{}
-}
-
-// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation.
-func (m *PaymentOrderMutation) QrCodeImgCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldQrCodeImg]
- return ok
+ return oldValue.AdoptDisplayName, nil
}
-// ResetQrCodeImg resets all changes to the "qr_code_img" field.
-func (m *PaymentOrderMutation) ResetQrCodeImg() {
- m.qr_code_img = nil
- delete(m.clearedFields, paymentorder.FieldQrCodeImg)
+// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() {
+ m.adopt_display_name = nil
}
-// SetOrderType sets the "order_type" field.
-func (m *PaymentOrderMutation) SetOrderType(s string) {
- m.order_type = &s
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) {
+ m.adopt_avatar = &b
}
-// OrderType returns the value of the "order_type" field in the mutation.
-func (m *PaymentOrderMutation) OrderType() (r string, exists bool) {
- v := m.order_type
+// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) {
+ v := m.adopt_avatar
if v == nil {
return
}
return *v, true
}
-// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldOrderType is only allowed on UpdateOne operations")
+ return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldOrderType requires an ID field in the mutation")
+ return v, errors.New("OldAdoptAvatar requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldOrderType: %w", err)
+ return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err)
}
- return oldValue.OrderType, nil
+ return oldValue.AdoptAvatar, nil
}
-// ResetOrderType resets all changes to the "order_type" field.
-func (m *PaymentOrderMutation) ResetOrderType() {
- m.order_type = nil
+// ResetAdoptAvatar resets all changes to the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() {
+ m.adopt_avatar = nil
}
-// SetPlanID sets the "plan_id" field.
-func (m *PaymentOrderMutation) SetPlanID(i int64) {
- m.plan_id = &i
- m.addplan_id = nil
+// SetDecidedAt sets the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) {
+ m.decided_at = &t
}
-// PlanID returns the value of the "plan_id" field in the mutation.
-func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) {
- v := m.plan_id
+// DecidedAt returns the value of the "decided_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) {
+ v := m.decided_at
if v == nil {
return
}
return *v, true
}
-// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) {
+func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPlanID is only allowed on UpdateOne operations")
+ return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPlanID requires an ID field in the mutation")
+ return v, errors.New("OldDecidedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPlanID: %w", err)
+ return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err)
}
- return oldValue.PlanID, nil
+ return oldValue.DecidedAt, nil
}
-// AddPlanID adds i to the "plan_id" field.
-func (m *PaymentOrderMutation) AddPlanID(i int64) {
- if m.addplan_id != nil {
- *m.addplan_id += i
- } else {
- m.addplan_id = &i
- }
+// ResetDecidedAt resets all changes to the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() {
+ m.decided_at = nil
}
-// AddedPlanID returns the value that was added to the "plan_id" field in this mutation.
-func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) {
- v := m.addplan_id
- if v == nil {
- return
- }
- return *v, true
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() {
+ m.clearedpending_auth_session = true
+ m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{}
}
-// ClearPlanID clears the value of the "plan_id" field.
-func (m *PaymentOrderMutation) ClearPlanID() {
- m.plan_id = nil
- m.addplan_id = nil
- m.clearedFields[paymentorder.FieldPlanID] = struct{}{}
+// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool {
+ return m.clearedpending_auth_session
}
-// PlanIDCleared returns if the "plan_id" field was cleared in this mutation.
-func (m *PaymentOrderMutation) PlanIDCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldPlanID]
- return ok
+// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// PendingAuthSessionID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) {
+ if id := m.pending_auth_session; id != nil {
+ ids = append(ids, *id)
+ }
+ return
}
-// ResetPlanID resets all changes to the "plan_id" field.
-func (m *PaymentOrderMutation) ResetPlanID() {
- m.plan_id = nil
- m.addplan_id = nil
- delete(m.clearedFields, paymentorder.FieldPlanID)
+// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() {
+ m.pending_auth_session = nil
+ m.clearedpending_auth_session = false
}
-// SetSubscriptionGroupID sets the "subscription_group_id" field.
-func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) {
- m.subscription_group_id = &i
- m.addsubscription_group_id = nil
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
}
-// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation.
-func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) {
- v := m.subscription_group_id
- if v == nil {
- return
- }
- return *v, true
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool {
+ return m.IdentityIDCleared() || m.clearedidentity
}
-// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err)
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
}
- return oldValue.SubscriptionGroupID, nil
+ return
}
-// AddSubscriptionGroupID adds i to the "subscription_group_id" field.
-func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) {
- if m.addsubscription_group_id != nil {
- *m.addsubscription_group_id += i
- } else {
- m.addsubscription_group_id = &i
- }
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
}
-// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation.
-func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) {
- v := m.addsubscription_group_id
- if v == nil {
- return
- }
- return *v, true
+// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder.
+func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) {
+ m.predicates = append(m.predicates, ps...)
}
-// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
-func (m *PaymentOrderMutation) ClearSubscriptionGroupID() {
- m.subscription_group_id = nil
- m.addsubscription_group_id = nil
- m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{}
+// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdentityAdoptionDecision, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
}
-// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation.
-func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID]
- return ok
+// Op returns the operation name.
+func (m *IdentityAdoptionDecisionMutation) Op() Op {
+ return m.op
}
-// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field.
-func (m *PaymentOrderMutation) ResetSubscriptionGroupID() {
- m.subscription_group_id = nil
- m.addsubscription_group_id = nil
- delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID)
+// SetOp allows setting the mutation operation.
+func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) {
+ m.op = op
}
-// SetSubscriptionDays sets the "subscription_days" field.
-func (m *PaymentOrderMutation) SetSubscriptionDays(i int) {
- m.subscription_days = &i
- m.addsubscription_days = nil
+// Type returns the node type of this mutation (IdentityAdoptionDecision).
+func (m *IdentityAdoptionDecisionMutation) Type() string {
+ return m.typ
}
-// SubscriptionDays returns the value of the "subscription_days" field in the mutation.
-func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) {
- v := m.subscription_days
- if v == nil {
- return
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdentityAdoptionDecisionMutation) Fields() []string {
+ fields := make([]string, 0, 7)
+ if m.created_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldCreatedAt)
}
- return *v, true
-}
-
-// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations")
+ if m.updated_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldUpdatedAt)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSubscriptionDays requires an ID field in the mutation")
+ if m.pending_auth_session != nil {
+ fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err)
+ if m.identity != nil {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
}
- return oldValue.SubscriptionDays, nil
-}
-
-// AddSubscriptionDays adds i to the "subscription_days" field.
-func (m *PaymentOrderMutation) AddSubscriptionDays(i int) {
- if m.addsubscription_days != nil {
- *m.addsubscription_days += i
- } else {
- m.addsubscription_days = &i
+ if m.adopt_display_name != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName)
+ }
+ if m.adopt_avatar != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptAvatar)
}
+ if m.decided_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldDecidedAt)
+ }
+ return fields
}
-// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation.
-func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) {
- v := m.addsubscription_days
- if v == nil {
- return
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.CreatedAt()
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.PendingAuthSessionID()
+ case identityadoptiondecision.FieldIdentityID:
+ return m.IdentityID()
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.AdoptDisplayName()
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.AdoptAvatar()
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.DecidedAt()
}
- return *v, true
+ return nil, false
}
-// ClearSubscriptionDays clears the value of the "subscription_days" field.
-func (m *PaymentOrderMutation) ClearSubscriptionDays() {
- m.subscription_days = nil
- m.addsubscription_days = nil
- m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{}
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.OldPendingAuthSessionID(ctx)
+ case identityadoptiondecision.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.OldAdoptDisplayName(ctx)
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.OldAdoptAvatar(ctx)
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.OldDecidedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
}
-// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation.
-func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays]
- return ok
-}
-
-// ResetSubscriptionDays resets all changes to the "subscription_days" field.
-func (m *PaymentOrderMutation) ResetSubscriptionDays() {
- m.subscription_days = nil
- m.addsubscription_days = nil
- delete(m.clearedFields, paymentorder.FieldSubscriptionDays)
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPendingAuthSessionID(v)
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptDisplayName(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptAvatar(v)
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDecidedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
}
-// SetProviderInstanceID sets the "provider_instance_id" field.
-func (m *PaymentOrderMutation) SetProviderInstanceID(s string) {
- m.provider_instance_id = &s
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedFields() []string {
+ var fields []string
+ return fields
}
-// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation.
-func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) {
- v := m.provider_instance_id
- if v == nil {
- return
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
}
- return *v, true
+ return nil, false
}
-// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldProviderInstanceID requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err)
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error {
+ switch name {
}
- return oldValue.ProviderInstanceID, nil
+ return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name)
}
-// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
-func (m *PaymentOrderMutation) ClearProviderInstanceID() {
- m.provider_instance_id = nil
- m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{}
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(identityadoptiondecision.FieldIdentityID) {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ return fields
}
-// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation.
-func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID]
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
return ok
}
-// ResetProviderInstanceID resets all changes to the "provider_instance_id" field.
-func (m *PaymentOrderMutation) ResetProviderInstanceID() {
- m.provider_instance_id = nil
- delete(m.clearedFields, paymentorder.FieldProviderInstanceID)
-}
-
-// SetStatus sets the "status" field.
-func (m *PaymentOrderMutation) SetStatus(s string) {
- m.status = &s
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldIdentityID:
+ m.ClearIdentityID()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name)
}
-// Status returns the value of the "status" field in the mutation.
-func (m *PaymentOrderMutation) Status() (r string, exists bool) {
- v := m.status
- if v == nil {
- return
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ m.ResetPendingAuthSessionID()
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ m.ResetAdoptDisplayName()
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ m.ResetAdoptAvatar()
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ m.ResetDecidedAt()
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
}
-// OldStatus returns the old "status" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldStatus is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldStatus requires an ID field in the mutation")
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.pending_auth_session != nil {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ if m.identity != nil {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
}
- return oldValue.Status, nil
+ return edges
}
-// ResetStatus resets all changes to the "status" field.
-func (m *PaymentOrderMutation) ResetStatus() {
- m.status = nil
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ if id := m.pending_auth_session; id != nil {
+ return []ent.Value{*id}
+ }
+ case identityadoptiondecision.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
}
-// SetRefundAmount sets the "refund_amount" field.
-func (m *PaymentOrderMutation) SetRefundAmount(f float64) {
- m.refund_amount = &f
- m.addrefund_amount = nil
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
}
-// RefundAmount returns the value of the "refund_amount" field in the mutation.
-func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) {
- v := m.refund_amount
- if v == nil {
- return
- }
- return *v, true
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
}
-// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundAmount requires an ID field in the mutation")
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedpending_auth_session {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err)
+ if m.clearedidentity {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
}
- return oldValue.RefundAmount, nil
+ return edges
}
-// AddRefundAmount adds f to the "refund_amount" field.
-func (m *PaymentOrderMutation) AddRefundAmount(f float64) {
- if m.addrefund_amount != nil {
- *m.addrefund_amount += f
- } else {
- m.addrefund_amount = &f
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ return m.clearedpending_auth_session
+ case identityadoptiondecision.EdgeIdentity:
+ return m.clearedidentity
}
+ return false
}
-// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation.
-func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) {
- v := m.addrefund_amount
- if v == nil {
- return
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ClearPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name)
}
-// ResetRefundAmount resets all changes to the "refund_amount" field.
-func (m *PaymentOrderMutation) ResetRefundAmount() {
- m.refund_amount = nil
- m.addrefund_amount = nil
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ResetPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name)
}
-// SetRefundReason sets the "refund_reason" field.
-func (m *PaymentOrderMutation) SetRefundReason(s string) {
- m.refund_reason = &s
+// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph.
+type PaymentAuditLogMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ order_id *string
+ action *string
+ detail *string
+ operator *string
+ created_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*PaymentAuditLog, error)
+ predicates []predicate.PaymentAuditLog
}
-// RefundReason returns the value of the "refund_reason" field in the mutation.
-func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) {
- v := m.refund_reason
- if v == nil {
- return
- }
- return *v, true
-}
+var _ ent.Mutation = (*PaymentAuditLogMutation)(nil)
-// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundReason is only allowed on UpdateOne operations")
+// paymentauditlogOption allows management of the mutation configuration using functional options.
+type paymentauditlogOption func(*PaymentAuditLogMutation)
+
+// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity.
+func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation {
+ m := &PaymentAuditLogMutation{
+ config: c,
+ op: op,
+ typ: TypePaymentAuditLog,
+ clearedFields: make(map[string]struct{}),
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundReason requires an ID field in the mutation")
+ for _, opt := range opts {
+ opt(m)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundReason: %w", err)
+ return m
+}
+
+// withPaymentAuditLogID sets the ID field of the mutation.
+func withPaymentAuditLogID(id int64) paymentauditlogOption {
+ return func(m *PaymentAuditLogMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PaymentAuditLog
+ )
+ m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PaymentAuditLog.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
}
- return oldValue.RefundReason, nil
}
-// ClearRefundReason clears the value of the "refund_reason" field.
-func (m *PaymentOrderMutation) ClearRefundReason() {
- m.refund_reason = nil
- m.clearedFields[paymentorder.FieldRefundReason] = struct{}{}
+// withPaymentAuditLog sets the old PaymentAuditLog of the mutation.
+func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption {
+ return func(m *PaymentAuditLogMutation) {
+ m.oldValue = func(context.Context) (*PaymentAuditLog, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
}
-// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation.
-func (m *PaymentOrderMutation) RefundReasonCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldRefundReason]
- return ok
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PaymentAuditLogMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
}
-// ResetRefundReason resets all changes to the "refund_reason" field.
-func (m *PaymentOrderMutation) ResetRefundReason() {
- m.refund_reason = nil
- delete(m.clearedFields, paymentorder.FieldRefundReason)
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PaymentAuditLogMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
}
-// SetRefundAt sets the "refund_at" field.
-func (m *PaymentOrderMutation) SetRefundAt(t time.Time) {
- m.refund_at = &t
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
}
-// RefundAt returns the value of the "refund_at" field in the mutation.
-func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) {
- v := m.refund_at
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetOrderID sets the "order_id" field.
+func (m *PaymentAuditLogMutation) SetOrderID(s string) {
+ m.order_id = &s
+}
+
+// OrderID returns the value of the "order_id" field in the mutation.
+func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) {
+ v := m.order_id
if v == nil {
return
}
return *v, true
}
-// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) {
+func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldOrderID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundAt requires an ID field in the mutation")
+ return v, errors.New("OldOrderID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldOrderID: %w", err)
}
- return oldValue.RefundAt, nil
-}
-
-// ClearRefundAt clears the value of the "refund_at" field.
-func (m *PaymentOrderMutation) ClearRefundAt() {
- m.refund_at = nil
- m.clearedFields[paymentorder.FieldRefundAt] = struct{}{}
-}
-
-// RefundAtCleared returns if the "refund_at" field was cleared in this mutation.
-func (m *PaymentOrderMutation) RefundAtCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldRefundAt]
- return ok
+ return oldValue.OrderID, nil
}
-// ResetRefundAt resets all changes to the "refund_at" field.
-func (m *PaymentOrderMutation) ResetRefundAt() {
- m.refund_at = nil
- delete(m.clearedFields, paymentorder.FieldRefundAt)
+// ResetOrderID resets all changes to the "order_id" field.
+func (m *PaymentAuditLogMutation) ResetOrderID() {
+ m.order_id = nil
}
-// SetForceRefund sets the "force_refund" field.
-func (m *PaymentOrderMutation) SetForceRefund(b bool) {
- m.force_refund = &b
+// SetAction sets the "action" field.
+func (m *PaymentAuditLogMutation) SetAction(s string) {
+ m.action = &s
}
-// ForceRefund returns the value of the "force_refund" field in the mutation.
-func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) {
- v := m.force_refund
+// Action returns the value of the "action" field in the mutation.
+func (m *PaymentAuditLogMutation) Action() (r string, exists bool) {
+ v := m.action
if v == nil {
return
}
return *v, true
}
-// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldAction returns the old "action" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) {
+func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldForceRefund is only allowed on UpdateOne operations")
+ return v, errors.New("OldAction is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldForceRefund requires an ID field in the mutation")
+ return v, errors.New("OldAction requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldForceRefund: %w", err)
+ return v, fmt.Errorf("querying old value for OldAction: %w", err)
}
- return oldValue.ForceRefund, nil
+ return oldValue.Action, nil
}
-// ResetForceRefund resets all changes to the "force_refund" field.
-func (m *PaymentOrderMutation) ResetForceRefund() {
- m.force_refund = nil
+// ResetAction resets all changes to the "action" field.
+func (m *PaymentAuditLogMutation) ResetAction() {
+ m.action = nil
}
-// SetRefundRequestedAt sets the "refund_requested_at" field.
-func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) {
- m.refund_requested_at = &t
+// SetDetail sets the "detail" field.
+func (m *PaymentAuditLogMutation) SetDetail(s string) {
+ m.detail = &s
}
-// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation.
-func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) {
- v := m.refund_requested_at
+// Detail returns the value of the "detail" field in the mutation.
+func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) {
+ v := m.detail
if v == nil {
return
}
return *v, true
}
-// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) {
+func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldDetail is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation")
+ return v, errors.New("OldDetail requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldDetail: %w", err)
}
- return oldValue.RefundRequestedAt, nil
-}
-
-// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
-func (m *PaymentOrderMutation) ClearRefundRequestedAt() {
- m.refund_requested_at = nil
- m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{}
-}
-
-// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation.
-func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt]
- return ok
+ return oldValue.Detail, nil
}
-// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field.
-func (m *PaymentOrderMutation) ResetRefundRequestedAt() {
- m.refund_requested_at = nil
- delete(m.clearedFields, paymentorder.FieldRefundRequestedAt)
+// ResetDetail resets all changes to the "detail" field.
+func (m *PaymentAuditLogMutation) ResetDetail() {
+ m.detail = nil
}
-// SetRefundRequestReason sets the "refund_request_reason" field.
-func (m *PaymentOrderMutation) SetRefundRequestReason(s string) {
- m.refund_request_reason = &s
+// SetOperator sets the "operator" field.
+func (m *PaymentAuditLogMutation) SetOperator(s string) {
+ m.operator = &s
}
-// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation.
-func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) {
- v := m.refund_request_reason
+// Operator returns the value of the "operator" field in the mutation.
+func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) {
+ v := m.operator
if v == nil {
return
}
return *v, true
}
-// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) {
+func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations")
+ return v, errors.New("OldOperator is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundRequestReason requires an ID field in the mutation")
+ return v, errors.New("OldOperator requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err)
+ return v, fmt.Errorf("querying old value for OldOperator: %w", err)
}
- return oldValue.RefundRequestReason, nil
+ return oldValue.Operator, nil
}
-// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
-func (m *PaymentOrderMutation) ClearRefundRequestReason() {
- m.refund_request_reason = nil
- m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{}
-}
-
-// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation.
-func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason]
- return ok
-}
-
-// ResetRefundRequestReason resets all changes to the "refund_request_reason" field.
-func (m *PaymentOrderMutation) ResetRefundRequestReason() {
- m.refund_request_reason = nil
- delete(m.clearedFields, paymentorder.FieldRefundRequestReason)
+// ResetOperator resets all changes to the "operator" field.
+func (m *PaymentAuditLogMutation) ResetOperator() {
+ m.operator = nil
}
-// SetRefundRequestedBy sets the "refund_requested_by" field.
-func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) {
- m.refund_requested_by = &s
+// SetCreatedAt sets the "created_at" field.
+func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
}
-// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation.
-func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) {
- v := m.refund_requested_by
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
if v == nil {
return
}
return *v, true
}
-// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) {
+func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations")
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation")
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err)
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
- return oldValue.RefundRequestedBy, nil
+ return oldValue.CreatedAt, nil
}
-// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
-func (m *PaymentOrderMutation) ClearRefundRequestedBy() {
- m.refund_requested_by = nil
- m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{}
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PaymentAuditLogMutation) ResetCreatedAt() {
+ m.created_at = nil
}
-// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation.
-func (m *PaymentOrderMutation) RefundRequestedByCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy]
- return ok
+// Where appends a list predicates to the PaymentAuditLogMutation builder.
+func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) {
+ m.predicates = append(m.predicates, ps...)
}
-// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field.
-func (m *PaymentOrderMutation) ResetRefundRequestedBy() {
- m.refund_requested_by = nil
- delete(m.clearedFields, paymentorder.FieldRefundRequestedBy)
+// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PaymentAuditLog, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
}
-// SetExpiresAt sets the "expires_at" field.
-func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) {
- m.expires_at = &t
+// Op returns the operation name.
+func (m *PaymentAuditLogMutation) Op() Op {
+ return m.op
}
-// ExpiresAt returns the value of the "expires_at" field in the mutation.
-func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) {
- v := m.expires_at
- if v == nil {
- return
- }
- return *v, true
+// SetOp allows setting the mutation operation.
+func (m *PaymentAuditLogMutation) SetOp(op Op) {
+ m.op = op
}
-// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+// Type returns the node type of this mutation (PaymentAuditLog).
+func (m *PaymentAuditLogMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PaymentAuditLogMutation) Fields() []string {
+ fields := make([]string, 0, 5)
+ if m.order_id != nil {
+ fields = append(fields, paymentauditlog.FieldOrderID)
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ if m.action != nil {
+ fields = append(fields, paymentauditlog.FieldAction)
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ if m.detail != nil {
+ fields = append(fields, paymentauditlog.FieldDetail)
}
- return oldValue.ExpiresAt, nil
-}
-
-// ResetExpiresAt resets all changes to the "expires_at" field.
-func (m *PaymentOrderMutation) ResetExpiresAt() {
- m.expires_at = nil
+ if m.operator != nil {
+ fields = append(fields, paymentauditlog.FieldOperator)
+ }
+ if m.created_at != nil {
+ fields = append(fields, paymentauditlog.FieldCreatedAt)
+ }
+ return fields
}
-// SetPaidAt sets the "paid_at" field.
-func (m *PaymentOrderMutation) SetPaidAt(t time.Time) {
- m.paid_at = &t
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ return m.OrderID()
+ case paymentauditlog.FieldAction:
+ return m.Action()
+ case paymentauditlog.FieldDetail:
+ return m.Detail()
+ case paymentauditlog.FieldOperator:
+ return m.Operator()
+ case paymentauditlog.FieldCreatedAt:
+ return m.CreatedAt()
+ }
+ return nil, false
}
-// PaidAt returns the value of the "paid_at" field in the mutation.
-func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) {
- v := m.paid_at
- if v == nil {
- return
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ return m.OldOrderID(ctx)
+ case paymentauditlog.FieldAction:
+ return m.OldAction(ctx)
+ case paymentauditlog.FieldDetail:
+ return m.OldDetail(ctx)
+ case paymentauditlog.FieldOperator:
+ return m.OldOperator(ctx)
+ case paymentauditlog.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
}
- return *v, true
+ return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name)
}
-// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPaidAt is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPaidAt requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldPaidAt: %w", err)
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOrderID(v)
+ return nil
+ case paymentauditlog.FieldAction:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAction(v)
+ return nil
+ case paymentauditlog.FieldDetail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDetail(v)
+ return nil
+ case paymentauditlog.FieldOperator:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOperator(v)
+ return nil
+ case paymentauditlog.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
}
- return oldValue.PaidAt, nil
+ return fmt.Errorf("unknown PaymentAuditLog field %s", name)
}
-// ClearPaidAt clears the value of the "paid_at" field.
-func (m *PaymentOrderMutation) ClearPaidAt() {
- m.paid_at = nil
- m.clearedFields[paymentorder.FieldPaidAt] = struct{}{}
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PaymentAuditLogMutation) AddedFields() []string {
+ return nil
}
-// PaidAtCleared returns if the "paid_at" field was cleared in this mutation.
-func (m *PaymentOrderMutation) PaidAtCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldPaidAt]
- return ok
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) {
+ return nil, false
}
-// ResetPaidAt resets all changes to the "paid_at" field.
-func (m *PaymentOrderMutation) ResetPaidAt() {
- m.paid_at = nil
- delete(m.clearedFields, paymentorder.FieldPaidAt)
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name)
}
-// SetCompletedAt sets the "completed_at" field.
-func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) {
- m.completed_at = &t
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PaymentAuditLogMutation) ClearedFields() []string {
+ return nil
}
-// CompletedAt returns the value of the "completed_at" field in the mutation.
-func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) {
- v := m.completed_at
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldCompletedAt requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err)
- }
- return oldValue.CompletedAt, nil
-}
-
-// ClearCompletedAt clears the value of the "completed_at" field.
-func (m *PaymentOrderMutation) ClearCompletedAt() {
- m.completed_at = nil
- m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{}
-}
-
-// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation.
-func (m *PaymentOrderMutation) CompletedAtCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldCompletedAt]
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PaymentAuditLogMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
return ok
}
-// ResetCompletedAt resets all changes to the "completed_at" field.
-func (m *PaymentOrderMutation) ResetCompletedAt() {
- m.completed_at = nil
- delete(m.clearedFields, paymentorder.FieldCompletedAt)
-}
-
-// SetFailedAt sets the "failed_at" field.
-func (m *PaymentOrderMutation) SetFailedAt(t time.Time) {
- m.failed_at = &t
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PaymentAuditLogMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name)
}
-// FailedAt returns the value of the "failed_at" field in the mutation.
-func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) {
- v := m.failed_at
- if v == nil {
- return
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PaymentAuditLogMutation) ResetField(name string) error {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ m.ResetOrderID()
+ return nil
+ case paymentauditlog.FieldAction:
+ m.ResetAction()
+ return nil
+ case paymentauditlog.FieldDetail:
+ m.ResetDetail()
+ return nil
+ case paymentauditlog.FieldOperator:
+ m.ResetOperator()
+ return nil
+ case paymentauditlog.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
}
- return *v, true
+ return fmt.Errorf("unknown PaymentAuditLog field %s", name)
}
-// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldFailedAt is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldFailedAt requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldFailedAt: %w", err)
- }
- return oldValue.FailedAt, nil
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PaymentAuditLogMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
}
-// ClearFailedAt clears the value of the "failed_at" field.
-func (m *PaymentOrderMutation) ClearFailedAt() {
- m.failed_at = nil
- m.clearedFields[paymentorder.FieldFailedAt] = struct{}{}
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value {
+ return nil
}
-// FailedAtCleared returns if the "failed_at" field was cleared in this mutation.
-func (m *PaymentOrderMutation) FailedAtCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldFailedAt]
- return ok
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PaymentAuditLogMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
}
-// ResetFailedAt resets all changes to the "failed_at" field.
-func (m *PaymentOrderMutation) ResetFailedAt() {
- m.failed_at = nil
- delete(m.clearedFields, paymentorder.FieldFailedAt)
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value {
+ return nil
}
-// SetFailedReason sets the "failed_reason" field.
-func (m *PaymentOrderMutation) SetFailedReason(s string) {
- m.failed_reason = &s
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PaymentAuditLogMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
}
-// FailedReason returns the value of the "failed_reason" field in the mutation.
-func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) {
- v := m.failed_reason
- if v == nil {
- return
- }
- return *v, true
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool {
+ return false
}
-// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldFailedReason is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldFailedReason requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldFailedReason: %w", err)
- }
- return oldValue.FailedReason, nil
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PaymentAuditLogMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name)
}
-// ClearFailedReason clears the value of the "failed_reason" field.
-func (m *PaymentOrderMutation) ClearFailedReason() {
- m.failed_reason = nil
- m.clearedFields[paymentorder.FieldFailedReason] = struct{}{}
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PaymentAuditLogMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown PaymentAuditLog edge %s", name)
}
-// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation.
-func (m *PaymentOrderMutation) FailedReasonCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldFailedReason]
- return ok
+// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph.
+type PaymentOrderMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ user_email *string
+ user_name *string
+ user_notes *string
+ amount *float64
+ addamount *float64
+ pay_amount *float64
+ addpay_amount *float64
+ fee_rate *float64
+ addfee_rate *float64
+ recharge_code *string
+ out_trade_no *string
+ payment_type *string
+ payment_trade_no *string
+ pay_url *string
+ qr_code *string
+ qr_code_img *string
+ order_type *string
+ plan_id *int64
+ addplan_id *int64
+ subscription_group_id *int64
+ addsubscription_group_id *int64
+ subscription_days *int
+ addsubscription_days *int
+ provider_instance_id *string
+ status *string
+ refund_amount *float64
+ addrefund_amount *float64
+ refund_reason *string
+ refund_at *time.Time
+ force_refund *bool
+ refund_requested_at *time.Time
+ refund_request_reason *string
+ refund_requested_by *string
+ expires_at *time.Time
+ paid_at *time.Time
+ completed_at *time.Time
+ failed_at *time.Time
+ failed_reason *string
+ client_ip *string
+ src_host *string
+ src_url *string
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ done bool
+ oldValue func(context.Context) (*PaymentOrder, error)
+ predicates []predicate.PaymentOrder
}
-// ResetFailedReason resets all changes to the "failed_reason" field.
-func (m *PaymentOrderMutation) ResetFailedReason() {
- m.failed_reason = nil
- delete(m.clearedFields, paymentorder.FieldFailedReason)
-}
+var _ ent.Mutation = (*PaymentOrderMutation)(nil)
-// SetClientIP sets the "client_ip" field.
-func (m *PaymentOrderMutation) SetClientIP(s string) {
- m.client_ip = &s
-}
+// paymentorderOption allows management of the mutation configuration using functional options.
+type paymentorderOption func(*PaymentOrderMutation)
-// ClientIP returns the value of the "client_ip" field in the mutation.
-func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) {
- v := m.client_ip
- if v == nil {
- return
+// newPaymentOrderMutation creates new mutation for the PaymentOrder entity.
+func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation {
+ m := &PaymentOrderMutation{
+ config: c,
+ op: op,
+ typ: TypePaymentOrder,
+ clearedFields: make(map[string]struct{}),
}
- return *v, true
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
}
-// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity.
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldClientIP is only allowed on UpdateOne operations")
+// withPaymentOrderID sets the ID field of the mutation.
+func withPaymentOrderID(id int64) paymentorderOption {
+ return func(m *PaymentOrderMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PaymentOrder
+ )
+ m.oldValue = func(ctx context.Context) (*PaymentOrder, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PaymentOrder.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
}
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldClientIP requires an ID field in the mutation")
+}
+
+// withPaymentOrder sets the old PaymentOrder of the mutation.
+func withPaymentOrder(node *PaymentOrder) paymentorderOption {
+ return func(m *PaymentOrderMutation) {
+ m.oldValue = func(context.Context) (*PaymentOrder, error) {
+ return node, nil
+ }
+ m.id = &node.ID
}
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldClientIP: %w", err)
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PaymentOrderMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PaymentOrderMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
}
- return oldValue.ClientIP, nil
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
}
-// ResetClientIP resets all changes to the "client_ip" field.
-func (m *PaymentOrderMutation) ResetClientIP() {
- m.client_ip = nil
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PaymentOrderMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
}
-// SetSrcHost sets the "src_host" field.
-func (m *PaymentOrderMutation) SetSrcHost(s string) {
- m.src_host = &s
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
}
-// SrcHost returns the value of the "src_host" field in the mutation.
-func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) {
- v := m.src_host
+// SetUserID sets the "user_id" field.
+func (m *PaymentOrderMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *PaymentOrderMutation) UserID() (r int64, exists bool) {
+ v := m.user
if v == nil {
return
}
return *v, true
}
-// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity.
+// OldUserID returns the old "user_id" field's value of the PaymentOrder entity.
// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) {
+func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSrcHost is only allowed on UpdateOne operations")
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSrcHost requires an ID field in the mutation")
+ return v, errors.New("OldUserID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSrcHost: %w", err)
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
}
- return oldValue.SrcHost, nil
+ return oldValue.UserID, nil
}
-// ResetSrcHost resets all changes to the "src_host" field.
-func (m *PaymentOrderMutation) ResetSrcHost() {
- m.src_host = nil
+// ResetUserID resets all changes to the "user_id" field.
+func (m *PaymentOrderMutation) ResetUserID() {
+ m.user = nil
}
-// SetSrcURL sets the "src_url" field.
-func (m *PaymentOrderMutation) SetSrcURL(s string) {
- m.src_url = &s
+// SetUserEmail sets the "user_email" field.
+func (m *PaymentOrderMutation) SetUserEmail(s string) {
+ m.user_email = &s
}
-// SrcURL returns the value of the "src_url" field in the mutation.
-func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) {
- v := m.src_url
+// UserEmail returns the value of the "user_email" field in the mutation.
+func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) {
+ v := m.user_email
if v == nil {
return
}
return *v, true
}
-// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity.
+// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity.
// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) {
+func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSrcURL is only allowed on UpdateOne operations")
+ return v, errors.New("OldUserEmail is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSrcURL requires an ID field in the mutation")
+ return v, errors.New("OldUserEmail requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSrcURL: %w", err)
+ return v, fmt.Errorf("querying old value for OldUserEmail: %w", err)
}
- return oldValue.SrcURL, nil
-}
-
-// ClearSrcURL clears the value of the "src_url" field.
-func (m *PaymentOrderMutation) ClearSrcURL() {
- m.src_url = nil
- m.clearedFields[paymentorder.FieldSrcURL] = struct{}{}
-}
-
-// SrcURLCleared returns if the "src_url" field was cleared in this mutation.
-func (m *PaymentOrderMutation) SrcURLCleared() bool {
- _, ok := m.clearedFields[paymentorder.FieldSrcURL]
- return ok
+ return oldValue.UserEmail, nil
}
-// ResetSrcURL resets all changes to the "src_url" field.
-func (m *PaymentOrderMutation) ResetSrcURL() {
- m.src_url = nil
- delete(m.clearedFields, paymentorder.FieldSrcURL)
+// ResetUserEmail resets all changes to the "user_email" field.
+func (m *PaymentOrderMutation) ResetUserEmail() {
+ m.user_email = nil
}
-// SetCreatedAt sets the "created_at" field.
-func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) {
- m.created_at = &t
+// SetUserName sets the "user_name" field.
+func (m *PaymentOrderMutation) SetUserName(s string) {
+ m.user_name = &s
}
-// CreatedAt returns the value of the "created_at" field in the mutation.
-func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) {
- v := m.created_at
+// UserName returns the value of the "user_name" field in the mutation.
+func (m *PaymentOrderMutation) UserName() (r string, exists bool) {
+ v := m.user_name
if v == nil {
return
}
return *v, true
}
-// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity.
+// OldUserName returns the old "user_name" field's value of the PaymentOrder entity.
// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldUserName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ return v, errors.New("OldUserName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldUserName: %w", err)
}
- return oldValue.CreatedAt, nil
+ return oldValue.UserName, nil
}
-// ResetCreatedAt resets all changes to the "created_at" field.
-func (m *PaymentOrderMutation) ResetCreatedAt() {
- m.created_at = nil
+// ResetUserName resets all changes to the "user_name" field.
+func (m *PaymentOrderMutation) ResetUserName() {
+ m.user_name = nil
}
-// SetUpdatedAt sets the "updated_at" field.
-func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) {
- m.updated_at = &t
+// SetUserNotes sets the "user_notes" field.
+func (m *PaymentOrderMutation) SetUserNotes(s string) {
+ m.user_notes = &s
}
-// UpdatedAt returns the value of the "updated_at" field in the mutation.
-func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) {
- v := m.updated_at
+// UserNotes returns the value of the "user_notes" field in the mutation.
+func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) {
+ v := m.user_notes
if v == nil {
return
}
return *v, true
}
-// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity.
+// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity.
// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldUserNotes is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ return v, errors.New("OldUserNotes requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldUserNotes: %w", err)
}
- return oldValue.UpdatedAt, nil
+ return oldValue.UserNotes, nil
}
-// ResetUpdatedAt resets all changes to the "updated_at" field.
-func (m *PaymentOrderMutation) ResetUpdatedAt() {
- m.updated_at = nil
+// ClearUserNotes clears the value of the "user_notes" field.
+func (m *PaymentOrderMutation) ClearUserNotes() {
+ m.user_notes = nil
+ m.clearedFields[paymentorder.FieldUserNotes] = struct{}{}
}
-// ClearUser clears the "user" edge to the User entity.
-func (m *PaymentOrderMutation) ClearUser() {
- m.cleareduser = true
- m.clearedFields[paymentorder.FieldUserID] = struct{}{}
+// UserNotesCleared returns if the "user_notes" field was cleared in this mutation.
+func (m *PaymentOrderMutation) UserNotesCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldUserNotes]
+ return ok
}
-// UserCleared reports if the "user" edge to the User entity was cleared.
-func (m *PaymentOrderMutation) UserCleared() bool {
- return m.cleareduser
+// ResetUserNotes resets all changes to the "user_notes" field.
+func (m *PaymentOrderMutation) ResetUserNotes() {
+ m.user_notes = nil
+ delete(m.clearedFields, paymentorder.FieldUserNotes)
}
-// UserIDs returns the "user" edge IDs in the mutation.
-// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
-// UserID instead. It exists only for internal usage by the builders.
-func (m *PaymentOrderMutation) UserIDs() (ids []int64) {
- if id := m.user; id != nil {
- ids = append(ids, *id)
+// SetAmount sets the "amount" field.
+func (m *PaymentOrderMutation) SetAmount(f float64) {
+ m.amount = &f
+ m.addamount = nil
+}
+
+// Amount returns the value of the "amount" field in the mutation.
+func (m *PaymentOrderMutation) Amount() (r float64, exists bool) {
+ v := m.amount
+ if v == nil {
+ return
}
- return
+ return *v, true
}
-// ResetUser resets all changes to the "user" edge.
-func (m *PaymentOrderMutation) ResetUser() {
- m.user = nil
- m.cleareduser = false
+// OldAmount returns the old "amount" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAmount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAmount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAmount: %w", err)
+ }
+ return oldValue.Amount, nil
}
-// Where appends a list predicates to the PaymentOrderMutation builder.
-func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) {
- m.predicates = append(m.predicates, ps...)
+// AddAmount adds f to the "amount" field.
+func (m *PaymentOrderMutation) AddAmount(f float64) {
+ if m.addamount != nil {
+ *m.addamount += f
+ } else {
+ m.addamount = &f
+ }
}
-// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method,
-// users can use type-assertion to append predicates that do not depend on any generated package.
-func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) {
- p := make([]predicate.PaymentOrder, len(ps))
- for i := range ps {
- p[i] = ps[i]
+// AddedAmount returns the value that was added to the "amount" field in this mutation.
+func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) {
+ v := m.addamount
+ if v == nil {
+ return
}
- m.Where(p...)
+ return *v, true
}
-// Op returns the operation name.
-func (m *PaymentOrderMutation) Op() Op {
- return m.op
+// ResetAmount resets all changes to the "amount" field.
+func (m *PaymentOrderMutation) ResetAmount() {
+ m.amount = nil
+ m.addamount = nil
}
-// SetOp allows setting the mutation operation.
-func (m *PaymentOrderMutation) SetOp(op Op) {
- m.op = op
+// SetPayAmount sets the "pay_amount" field.
+func (m *PaymentOrderMutation) SetPayAmount(f float64) {
+ m.pay_amount = &f
+ m.addpay_amount = nil
}
-// Type returns the node type of this mutation (PaymentOrder).
-func (m *PaymentOrderMutation) Type() string {
- return m.typ
+// PayAmount returns the value of the "pay_amount" field in the mutation.
+func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) {
+ v := m.pay_amount
+ if v == nil {
+ return
+ }
+ return *v, true
}
-// Fields returns all fields that were changed during this mutation. Note that in
-// order to get all numeric fields that were incremented/decremented, call
-// AddedFields().
-func (m *PaymentOrderMutation) Fields() []string {
- fields := make([]string, 0, 37)
- if m.user != nil {
- fields = append(fields, paymentorder.FieldUserID)
- }
- if m.user_email != nil {
- fields = append(fields, paymentorder.FieldUserEmail)
- }
- if m.user_name != nil {
- fields = append(fields, paymentorder.FieldUserName)
+// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPayAmount is only allowed on UpdateOne operations")
}
- if m.user_notes != nil {
- fields = append(fields, paymentorder.FieldUserNotes)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPayAmount requires an ID field in the mutation")
}
- if m.amount != nil {
- fields = append(fields, paymentorder.FieldAmount)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPayAmount: %w", err)
}
- if m.pay_amount != nil {
- fields = append(fields, paymentorder.FieldPayAmount)
+ return oldValue.PayAmount, nil
+}
+
+// AddPayAmount adds f to the "pay_amount" field.
+func (m *PaymentOrderMutation) AddPayAmount(f float64) {
+ if m.addpay_amount != nil {
+ *m.addpay_amount += f
+ } else {
+ m.addpay_amount = &f
}
- if m.fee_rate != nil {
- fields = append(fields, paymentorder.FieldFeeRate)
+}
+
+// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation.
+func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) {
+ v := m.addpay_amount
+ if v == nil {
+ return
}
- if m.recharge_code != nil {
- fields = append(fields, paymentorder.FieldRechargeCode)
+ return *v, true
+}
+
+// ResetPayAmount resets all changes to the "pay_amount" field.
+func (m *PaymentOrderMutation) ResetPayAmount() {
+ m.pay_amount = nil
+ m.addpay_amount = nil
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (m *PaymentOrderMutation) SetFeeRate(f float64) {
+ m.fee_rate = &f
+ m.addfee_rate = nil
+}
+
+// FeeRate returns the value of the "fee_rate" field in the mutation.
+func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) {
+ v := m.fee_rate
+ if v == nil {
+ return
}
- if m.out_trade_no != nil {
- fields = append(fields, paymentorder.FieldOutTradeNo)
+ return *v, true
+}
+
+// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFeeRate is only allowed on UpdateOne operations")
}
- if m.payment_type != nil {
- fields = append(fields, paymentorder.FieldPaymentType)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFeeRate requires an ID field in the mutation")
}
- if m.payment_trade_no != nil {
- fields = append(fields, paymentorder.FieldPaymentTradeNo)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFeeRate: %w", err)
}
- if m.pay_url != nil {
- fields = append(fields, paymentorder.FieldPayURL)
+ return oldValue.FeeRate, nil
+}
+
+// AddFeeRate adds f to the "fee_rate" field.
+func (m *PaymentOrderMutation) AddFeeRate(f float64) {
+ if m.addfee_rate != nil {
+ *m.addfee_rate += f
+ } else {
+ m.addfee_rate = &f
}
- if m.qr_code != nil {
- fields = append(fields, paymentorder.FieldQrCode)
+}
+
+// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation.
+func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) {
+ v := m.addfee_rate
+ if v == nil {
+ return
}
- if m.qr_code_img != nil {
- fields = append(fields, paymentorder.FieldQrCodeImg)
+ return *v, true
+}
+
+// ResetFeeRate resets all changes to the "fee_rate" field.
+func (m *PaymentOrderMutation) ResetFeeRate() {
+ m.fee_rate = nil
+ m.addfee_rate = nil
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (m *PaymentOrderMutation) SetRechargeCode(s string) {
+ m.recharge_code = &s
+}
+
+// RechargeCode returns the value of the "recharge_code" field in the mutation.
+func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) {
+ v := m.recharge_code
+ if v == nil {
+ return
}
- if m.order_type != nil {
- fields = append(fields, paymentorder.FieldOrderType)
+ return *v, true
+}
+
+// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations")
}
- if m.plan_id != nil {
- fields = append(fields, paymentorder.FieldPlanID)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRechargeCode requires an ID field in the mutation")
}
- if m.subscription_group_id != nil {
- fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err)
}
- if m.subscription_days != nil {
- fields = append(fields, paymentorder.FieldSubscriptionDays)
+ return oldValue.RechargeCode, nil
+}
+
+// ResetRechargeCode resets all changes to the "recharge_code" field.
+func (m *PaymentOrderMutation) ResetRechargeCode() {
+ m.recharge_code = nil
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (m *PaymentOrderMutation) SetOutTradeNo(s string) {
+ m.out_trade_no = &s
+}
+
+// OutTradeNo returns the value of the "out_trade_no" field in the mutation.
+func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) {
+ v := m.out_trade_no
+ if v == nil {
+ return
}
- if m.provider_instance_id != nil {
- fields = append(fields, paymentorder.FieldProviderInstanceID)
+ return *v, true
+}
+
+// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations")
}
- if m.status != nil {
- fields = append(fields, paymentorder.FieldStatus)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOutTradeNo requires an ID field in the mutation")
}
- if m.refund_amount != nil {
- fields = append(fields, paymentorder.FieldRefundAmount)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err)
}
- if m.refund_reason != nil {
- fields = append(fields, paymentorder.FieldRefundReason)
- }
- if m.refund_at != nil {
- fields = append(fields, paymentorder.FieldRefundAt)
- }
- if m.force_refund != nil {
- fields = append(fields, paymentorder.FieldForceRefund)
- }
- if m.refund_requested_at != nil {
- fields = append(fields, paymentorder.FieldRefundRequestedAt)
+ return oldValue.OutTradeNo, nil
+}
+
+// ResetOutTradeNo resets all changes to the "out_trade_no" field.
+func (m *PaymentOrderMutation) ResetOutTradeNo() {
+ m.out_trade_no = nil
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (m *PaymentOrderMutation) SetPaymentType(s string) {
+ m.payment_type = &s
+}
+
+// PaymentType returns the value of the "payment_type" field in the mutation.
+func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) {
+ v := m.payment_type
+ if v == nil {
+ return
}
- if m.refund_request_reason != nil {
- fields = append(fields, paymentorder.FieldRefundRequestReason)
+ return *v, true
+}
+
+// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaymentType is only allowed on UpdateOne operations")
}
- if m.refund_requested_by != nil {
- fields = append(fields, paymentorder.FieldRefundRequestedBy)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaymentType requires an ID field in the mutation")
}
- if m.expires_at != nil {
- fields = append(fields, paymentorder.FieldExpiresAt)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaymentType: %w", err)
}
- if m.paid_at != nil {
- fields = append(fields, paymentorder.FieldPaidAt)
+ return oldValue.PaymentType, nil
+}
+
+// ResetPaymentType resets all changes to the "payment_type" field.
+func (m *PaymentOrderMutation) ResetPaymentType() {
+ m.payment_type = nil
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) {
+ m.payment_trade_no = &s
+}
+
+// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation.
+func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) {
+ v := m.payment_trade_no
+ if v == nil {
+ return
}
- if m.completed_at != nil {
- fields = append(fields, paymentorder.FieldCompletedAt)
+ return *v, true
+}
+
+// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations")
}
- if m.failed_at != nil {
- fields = append(fields, paymentorder.FieldFailedAt)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation")
}
- if m.failed_reason != nil {
- fields = append(fields, paymentorder.FieldFailedReason)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err)
}
- if m.client_ip != nil {
- fields = append(fields, paymentorder.FieldClientIP)
+ return oldValue.PaymentTradeNo, nil
+}
+
+// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field.
+func (m *PaymentOrderMutation) ResetPaymentTradeNo() {
+ m.payment_trade_no = nil
+}
+
+// SetPayURL sets the "pay_url" field.
+func (m *PaymentOrderMutation) SetPayURL(s string) {
+ m.pay_url = &s
+}
+
+// PayURL returns the value of the "pay_url" field in the mutation.
+func (m *PaymentOrderMutation) PayURL() (r string, exists bool) {
+ v := m.pay_url
+ if v == nil {
+ return
}
- if m.src_host != nil {
- fields = append(fields, paymentorder.FieldSrcHost)
+ return *v, true
+}
+
+// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPayURL is only allowed on UpdateOne operations")
}
- if m.src_url != nil {
- fields = append(fields, paymentorder.FieldSrcURL)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPayURL requires an ID field in the mutation")
}
- if m.created_at != nil {
- fields = append(fields, paymentorder.FieldCreatedAt)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPayURL: %w", err)
}
- if m.updated_at != nil {
- fields = append(fields, paymentorder.FieldUpdatedAt)
+ return oldValue.PayURL, nil
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (m *PaymentOrderMutation) ClearPayURL() {
+ m.pay_url = nil
+ m.clearedFields[paymentorder.FieldPayURL] = struct{}{}
+}
+
+// PayURLCleared returns if the "pay_url" field was cleared in this mutation.
+func (m *PaymentOrderMutation) PayURLCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldPayURL]
+ return ok
+}
+
+// ResetPayURL resets all changes to the "pay_url" field.
+func (m *PaymentOrderMutation) ResetPayURL() {
+ m.pay_url = nil
+ delete(m.clearedFields, paymentorder.FieldPayURL)
+}
+
+// SetQrCode sets the "qr_code" field.
+func (m *PaymentOrderMutation) SetQrCode(s string) {
+ m.qr_code = &s
+}
+
+// QrCode returns the value of the "qr_code" field in the mutation.
+func (m *PaymentOrderMutation) QrCode() (r string, exists bool) {
+ v := m.qr_code
+ if v == nil {
+ return
}
- return fields
+ return *v, true
}
-// Field returns the value of a field with the given name. The second boolean
-// return value indicates that this field was not set, or was not defined in the
-// schema.
-func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
- switch name {
- case paymentorder.FieldUserID:
- return m.UserID()
- case paymentorder.FieldUserEmail:
- return m.UserEmail()
- case paymentorder.FieldUserName:
- return m.UserName()
- case paymentorder.FieldUserNotes:
- return m.UserNotes()
- case paymentorder.FieldAmount:
- return m.Amount()
- case paymentorder.FieldPayAmount:
- return m.PayAmount()
- case paymentorder.FieldFeeRate:
- return m.FeeRate()
- case paymentorder.FieldRechargeCode:
- return m.RechargeCode()
- case paymentorder.FieldOutTradeNo:
- return m.OutTradeNo()
- case paymentorder.FieldPaymentType:
- return m.PaymentType()
- case paymentorder.FieldPaymentTradeNo:
- return m.PaymentTradeNo()
- case paymentorder.FieldPayURL:
- return m.PayURL()
- case paymentorder.FieldQrCode:
- return m.QrCode()
- case paymentorder.FieldQrCodeImg:
- return m.QrCodeImg()
- case paymentorder.FieldOrderType:
- return m.OrderType()
- case paymentorder.FieldPlanID:
- return m.PlanID()
- case paymentorder.FieldSubscriptionGroupID:
- return m.SubscriptionGroupID()
- case paymentorder.FieldSubscriptionDays:
- return m.SubscriptionDays()
- case paymentorder.FieldProviderInstanceID:
- return m.ProviderInstanceID()
- case paymentorder.FieldStatus:
- return m.Status()
- case paymentorder.FieldRefundAmount:
- return m.RefundAmount()
- case paymentorder.FieldRefundReason:
- return m.RefundReason()
- case paymentorder.FieldRefundAt:
- return m.RefundAt()
- case paymentorder.FieldForceRefund:
- return m.ForceRefund()
- case paymentorder.FieldRefundRequestedAt:
- return m.RefundRequestedAt()
- case paymentorder.FieldRefundRequestReason:
- return m.RefundRequestReason()
- case paymentorder.FieldRefundRequestedBy:
- return m.RefundRequestedBy()
- case paymentorder.FieldExpiresAt:
- return m.ExpiresAt()
- case paymentorder.FieldPaidAt:
- return m.PaidAt()
- case paymentorder.FieldCompletedAt:
- return m.CompletedAt()
- case paymentorder.FieldFailedAt:
- return m.FailedAt()
- case paymentorder.FieldFailedReason:
- return m.FailedReason()
- case paymentorder.FieldClientIP:
- return m.ClientIP()
- case paymentorder.FieldSrcHost:
- return m.SrcHost()
- case paymentorder.FieldSrcURL:
- return m.SrcURL()
- case paymentorder.FieldCreatedAt:
- return m.CreatedAt()
- case paymentorder.FieldUpdatedAt:
- return m.UpdatedAt()
+// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldQrCode is only allowed on UpdateOne operations")
}
- return nil, false
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldQrCode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldQrCode: %w", err)
+ }
+ return oldValue.QrCode, nil
}
-// OldField returns the old value of the field from the database. An error is
-// returned if the mutation operation is not UpdateOne, or the query to the
-// database failed.
-func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
- switch name {
- case paymentorder.FieldUserID:
- return m.OldUserID(ctx)
- case paymentorder.FieldUserEmail:
- return m.OldUserEmail(ctx)
- case paymentorder.FieldUserName:
- return m.OldUserName(ctx)
- case paymentorder.FieldUserNotes:
- return m.OldUserNotes(ctx)
- case paymentorder.FieldAmount:
- return m.OldAmount(ctx)
- case paymentorder.FieldPayAmount:
- return m.OldPayAmount(ctx)
- case paymentorder.FieldFeeRate:
- return m.OldFeeRate(ctx)
- case paymentorder.FieldRechargeCode:
- return m.OldRechargeCode(ctx)
- case paymentorder.FieldOutTradeNo:
- return m.OldOutTradeNo(ctx)
- case paymentorder.FieldPaymentType:
- return m.OldPaymentType(ctx)
- case paymentorder.FieldPaymentTradeNo:
- return m.OldPaymentTradeNo(ctx)
- case paymentorder.FieldPayURL:
- return m.OldPayURL(ctx)
- case paymentorder.FieldQrCode:
- return m.OldQrCode(ctx)
- case paymentorder.FieldQrCodeImg:
- return m.OldQrCodeImg(ctx)
- case paymentorder.FieldOrderType:
- return m.OldOrderType(ctx)
- case paymentorder.FieldPlanID:
- return m.OldPlanID(ctx)
- case paymentorder.FieldSubscriptionGroupID:
- return m.OldSubscriptionGroupID(ctx)
- case paymentorder.FieldSubscriptionDays:
- return m.OldSubscriptionDays(ctx)
- case paymentorder.FieldProviderInstanceID:
- return m.OldProviderInstanceID(ctx)
- case paymentorder.FieldStatus:
- return m.OldStatus(ctx)
- case paymentorder.FieldRefundAmount:
- return m.OldRefundAmount(ctx)
- case paymentorder.FieldRefundReason:
- return m.OldRefundReason(ctx)
- case paymentorder.FieldRefundAt:
- return m.OldRefundAt(ctx)
- case paymentorder.FieldForceRefund:
- return m.OldForceRefund(ctx)
- case paymentorder.FieldRefundRequestedAt:
- return m.OldRefundRequestedAt(ctx)
- case paymentorder.FieldRefundRequestReason:
- return m.OldRefundRequestReason(ctx)
- case paymentorder.FieldRefundRequestedBy:
- return m.OldRefundRequestedBy(ctx)
- case paymentorder.FieldExpiresAt:
- return m.OldExpiresAt(ctx)
- case paymentorder.FieldPaidAt:
- return m.OldPaidAt(ctx)
- case paymentorder.FieldCompletedAt:
- return m.OldCompletedAt(ctx)
- case paymentorder.FieldFailedAt:
- return m.OldFailedAt(ctx)
- case paymentorder.FieldFailedReason:
- return m.OldFailedReason(ctx)
- case paymentorder.FieldClientIP:
- return m.OldClientIP(ctx)
- case paymentorder.FieldSrcHost:
- return m.OldSrcHost(ctx)
- case paymentorder.FieldSrcURL:
- return m.OldSrcURL(ctx)
- case paymentorder.FieldCreatedAt:
- return m.OldCreatedAt(ctx)
- case paymentorder.FieldUpdatedAt:
- return m.OldUpdatedAt(ctx)
+// ClearQrCode clears the value of the "qr_code" field.
+func (m *PaymentOrderMutation) ClearQrCode() {
+ m.qr_code = nil
+ m.clearedFields[paymentorder.FieldQrCode] = struct{}{}
+}
+
+// QrCodeCleared returns if the "qr_code" field was cleared in this mutation.
+func (m *PaymentOrderMutation) QrCodeCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldQrCode]
+ return ok
+}
+
+// ResetQrCode resets all changes to the "qr_code" field.
+func (m *PaymentOrderMutation) ResetQrCode() {
+ m.qr_code = nil
+ delete(m.clearedFields, paymentorder.FieldQrCode)
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (m *PaymentOrderMutation) SetQrCodeImg(s string) {
+ m.qr_code_img = &s
+}
+
+// QrCodeImg returns the value of the "qr_code_img" field in the mutation.
+func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) {
+ v := m.qr_code_img
+ if v == nil {
+ return
}
- return nil, fmt.Errorf("unknown PaymentOrder field %s", name)
+ return *v, true
}
-// SetField sets the value of a field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
- switch name {
- case paymentorder.FieldUserID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetUserID(v)
- return nil
- case paymentorder.FieldUserEmail:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetUserEmail(v)
- return nil
- case paymentorder.FieldUserName:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetUserName(v)
- return nil
- case paymentorder.FieldUserNotes:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetUserNotes(v)
- return nil
- case paymentorder.FieldAmount:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetAmount(v)
- return nil
- case paymentorder.FieldPayAmount:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPayAmount(v)
- return nil
- case paymentorder.FieldFeeRate:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetFeeRate(v)
- return nil
- case paymentorder.FieldRechargeCode:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRechargeCode(v)
- return nil
- case paymentorder.FieldOutTradeNo:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetOutTradeNo(v)
- return nil
- case paymentorder.FieldPaymentType:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPaymentType(v)
- return nil
- case paymentorder.FieldPaymentTradeNo:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPaymentTradeNo(v)
- return nil
- case paymentorder.FieldPayURL:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPayURL(v)
- return nil
- case paymentorder.FieldQrCode:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetQrCode(v)
- return nil
- case paymentorder.FieldQrCodeImg:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetQrCodeImg(v)
- return nil
- case paymentorder.FieldOrderType:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetOrderType(v)
- return nil
- case paymentorder.FieldPlanID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPlanID(v)
- return nil
- case paymentorder.FieldSubscriptionGroupID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSubscriptionGroupID(v)
- return nil
- case paymentorder.FieldSubscriptionDays:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSubscriptionDays(v)
- return nil
- case paymentorder.FieldProviderInstanceID:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetProviderInstanceID(v)
- return nil
- case paymentorder.FieldStatus:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetStatus(v)
- return nil
- case paymentorder.FieldRefundAmount:
- v, ok := value.(float64)
+// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldQrCodeImg requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err)
+ }
+ return oldValue.QrCodeImg, nil
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (m *PaymentOrderMutation) ClearQrCodeImg() {
+ m.qr_code_img = nil
+ m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{}
+}
+
+// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation.
+func (m *PaymentOrderMutation) QrCodeImgCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldQrCodeImg]
+ return ok
+}
+
+// ResetQrCodeImg resets all changes to the "qr_code_img" field.
+func (m *PaymentOrderMutation) ResetQrCodeImg() {
+ m.qr_code_img = nil
+ delete(m.clearedFields, paymentorder.FieldQrCodeImg)
+}
+
+// SetOrderType sets the "order_type" field.
+func (m *PaymentOrderMutation) SetOrderType(s string) {
+ m.order_type = &s
+}
+
+// OrderType returns the value of the "order_type" field in the mutation.
+func (m *PaymentOrderMutation) OrderType() (r string, exists bool) {
+ v := m.order_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOrderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOrderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOrderType: %w", err)
+ }
+ return oldValue.OrderType, nil
+}
+
+// ResetOrderType resets all changes to the "order_type" field.
+func (m *PaymentOrderMutation) ResetOrderType() {
+ m.order_type = nil
+}
+
+// SetPlanID sets the "plan_id" field.
+func (m *PaymentOrderMutation) SetPlanID(i int64) {
+ m.plan_id = &i
+ m.addplan_id = nil
+}
+
+// PlanID returns the value of the "plan_id" field in the mutation.
+func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) {
+ v := m.plan_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPlanID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPlanID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPlanID: %w", err)
+ }
+ return oldValue.PlanID, nil
+}
+
+// AddPlanID adds i to the "plan_id" field.
+func (m *PaymentOrderMutation) AddPlanID(i int64) {
+ if m.addplan_id != nil {
+ *m.addplan_id += i
+ } else {
+ m.addplan_id = &i
+ }
+}
+
+// AddedPlanID returns the value that was added to the "plan_id" field in this mutation.
+func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) {
+ v := m.addplan_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (m *PaymentOrderMutation) ClearPlanID() {
+ m.plan_id = nil
+ m.addplan_id = nil
+ m.clearedFields[paymentorder.FieldPlanID] = struct{}{}
+}
+
+// PlanIDCleared returns if the "plan_id" field was cleared in this mutation.
+func (m *PaymentOrderMutation) PlanIDCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldPlanID]
+ return ok
+}
+
+// ResetPlanID resets all changes to the "plan_id" field.
+func (m *PaymentOrderMutation) ResetPlanID() {
+ m.plan_id = nil
+ m.addplan_id = nil
+ delete(m.clearedFields, paymentorder.FieldPlanID)
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) {
+ m.subscription_group_id = &i
+ m.addsubscription_group_id = nil
+}
+
+// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation.
+func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) {
+ v := m.subscription_group_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err)
+ }
+ return oldValue.SubscriptionGroupID, nil
+}
+
+// AddSubscriptionGroupID adds i to the "subscription_group_id" field.
+func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) {
+ if m.addsubscription_group_id != nil {
+ *m.addsubscription_group_id += i
+ } else {
+ m.addsubscription_group_id = &i
+ }
+}
+
+// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation.
+func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) {
+ v := m.addsubscription_group_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (m *PaymentOrderMutation) ClearSubscriptionGroupID() {
+ m.subscription_group_id = nil
+ m.addsubscription_group_id = nil
+ m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{}
+}
+
+// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation.
+func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID]
+ return ok
+}
+
+// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field.
+func (m *PaymentOrderMutation) ResetSubscriptionGroupID() {
+ m.subscription_group_id = nil
+ m.addsubscription_group_id = nil
+ delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID)
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (m *PaymentOrderMutation) SetSubscriptionDays(i int) {
+ m.subscription_days = &i
+ m.addsubscription_days = nil
+}
+
+// SubscriptionDays returns the value of the "subscription_days" field in the mutation.
+func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) {
+ v := m.subscription_days
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSubscriptionDays requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err)
+ }
+ return oldValue.SubscriptionDays, nil
+}
+
+// AddSubscriptionDays adds i to the "subscription_days" field.
+func (m *PaymentOrderMutation) AddSubscriptionDays(i int) {
+ if m.addsubscription_days != nil {
+ *m.addsubscription_days += i
+ } else {
+ m.addsubscription_days = &i
+ }
+}
+
+// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation.
+func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) {
+ v := m.addsubscription_days
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (m *PaymentOrderMutation) ClearSubscriptionDays() {
+ m.subscription_days = nil
+ m.addsubscription_days = nil
+ m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{}
+}
+
+// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation.
+func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays]
+ return ok
+}
+
+// ResetSubscriptionDays resets all changes to the "subscription_days" field.
+func (m *PaymentOrderMutation) ResetSubscriptionDays() {
+ m.subscription_days = nil
+ m.addsubscription_days = nil
+ delete(m.clearedFields, paymentorder.FieldSubscriptionDays)
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (m *PaymentOrderMutation) SetProviderInstanceID(s string) {
+ m.provider_instance_id = &s
+}
+
+// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation.
+func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) {
+ v := m.provider_instance_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderInstanceID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err)
+ }
+ return oldValue.ProviderInstanceID, nil
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (m *PaymentOrderMutation) ClearProviderInstanceID() {
+ m.provider_instance_id = nil
+ m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{}
+}
+
+// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID]
+ return ok
+}
+
+// ResetProviderInstanceID resets all changes to the "provider_instance_id" field.
+func (m *PaymentOrderMutation) ResetProviderInstanceID() {
+ m.provider_instance_id = nil
+ delete(m.clearedFields, paymentorder.FieldProviderInstanceID)
+}
+
+// SetStatus sets the "status" field.
+func (m *PaymentOrderMutation) SetStatus(s string) {
+ m.status = &s
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *PaymentOrderMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *PaymentOrderMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (m *PaymentOrderMutation) SetRefundAmount(f float64) {
+ m.refund_amount = &f
+ m.addrefund_amount = nil
+}
+
+// RefundAmount returns the value of the "refund_amount" field in the mutation.
+func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) {
+ v := m.refund_amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundAmount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err)
+ }
+ return oldValue.RefundAmount, nil
+}
+
+// AddRefundAmount adds f to the "refund_amount" field.
+func (m *PaymentOrderMutation) AddRefundAmount(f float64) {
+ if m.addrefund_amount != nil {
+ *m.addrefund_amount += f
+ } else {
+ m.addrefund_amount = &f
+ }
+}
+
+// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation.
+func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) {
+ v := m.addrefund_amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRefundAmount resets all changes to the "refund_amount" field.
+func (m *PaymentOrderMutation) ResetRefundAmount() {
+ m.refund_amount = nil
+ m.addrefund_amount = nil
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (m *PaymentOrderMutation) SetRefundReason(s string) {
+ m.refund_reason = &s
+}
+
+// RefundReason returns the value of the "refund_reason" field in the mutation.
+func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) {
+ v := m.refund_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundReason: %w", err)
+ }
+ return oldValue.RefundReason, nil
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (m *PaymentOrderMutation) ClearRefundReason() {
+ m.refund_reason = nil
+ m.clearedFields[paymentorder.FieldRefundReason] = struct{}{}
+}
+
+// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundReasonCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundReason]
+ return ok
+}
+
+// ResetRefundReason resets all changes to the "refund_reason" field.
+func (m *PaymentOrderMutation) ResetRefundReason() {
+ m.refund_reason = nil
+ delete(m.clearedFields, paymentorder.FieldRefundReason)
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (m *PaymentOrderMutation) SetRefundAt(t time.Time) {
+ m.refund_at = &t
+}
+
+// RefundAt returns the value of the "refund_at" field in the mutation.
+func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) {
+ v := m.refund_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundAt: %w", err)
+ }
+ return oldValue.RefundAt, nil
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (m *PaymentOrderMutation) ClearRefundAt() {
+ m.refund_at = nil
+ m.clearedFields[paymentorder.FieldRefundAt] = struct{}{}
+}
+
+// RefundAtCleared returns if the "refund_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundAt]
+ return ok
+}
+
+// ResetRefundAt resets all changes to the "refund_at" field.
+func (m *PaymentOrderMutation) ResetRefundAt() {
+ m.refund_at = nil
+ delete(m.clearedFields, paymentorder.FieldRefundAt)
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (m *PaymentOrderMutation) SetForceRefund(b bool) {
+ m.force_refund = &b
+}
+
+// ForceRefund returns the value of the "force_refund" field in the mutation.
+func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) {
+ v := m.force_refund
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldForceRefund is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldForceRefund requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldForceRefund: %w", err)
+ }
+ return oldValue.ForceRefund, nil
+}
+
+// ResetForceRefund resets all changes to the "force_refund" field.
+func (m *PaymentOrderMutation) ResetForceRefund() {
+ m.force_refund = nil
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) {
+ m.refund_requested_at = &t
+}
+
+// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation.
+func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) {
+ v := m.refund_requested_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err)
+ }
+ return oldValue.RefundRequestedAt, nil
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (m *PaymentOrderMutation) ClearRefundRequestedAt() {
+ m.refund_requested_at = nil
+ m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{}
+}
+
+// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt]
+ return ok
+}
+
+// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field.
+func (m *PaymentOrderMutation) ResetRefundRequestedAt() {
+ m.refund_requested_at = nil
+ delete(m.clearedFields, paymentorder.FieldRefundRequestedAt)
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (m *PaymentOrderMutation) SetRefundRequestReason(s string) {
+ m.refund_request_reason = &s
+}
+
+// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation.
+func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) {
+ v := m.refund_request_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundRequestReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err)
+ }
+ return oldValue.RefundRequestReason, nil
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (m *PaymentOrderMutation) ClearRefundRequestReason() {
+ m.refund_request_reason = nil
+ m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{}
+}
+
+// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason]
+ return ok
+}
+
+// ResetRefundRequestReason resets all changes to the "refund_request_reason" field.
+func (m *PaymentOrderMutation) ResetRefundRequestReason() {
+ m.refund_request_reason = nil
+ delete(m.clearedFields, paymentorder.FieldRefundRequestReason)
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) {
+ m.refund_requested_by = &s
+}
+
+// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation.
+func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) {
+ v := m.refund_requested_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err)
+ }
+ return oldValue.RefundRequestedBy, nil
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (m *PaymentOrderMutation) ClearRefundRequestedBy() {
+ m.refund_requested_by = nil
+ m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{}
+}
+
+// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundRequestedByCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy]
+ return ok
+}
+
+// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field.
+func (m *PaymentOrderMutation) ResetRefundRequestedBy() {
+ m.refund_requested_by = nil
+ delete(m.clearedFields, paymentorder.FieldRefundRequestedBy)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PaymentOrderMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (m *PaymentOrderMutation) SetPaidAt(t time.Time) {
+ m.paid_at = &t
+}
+
+// PaidAt returns the value of the "paid_at" field in the mutation.
+func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) {
+ v := m.paid_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaidAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaidAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaidAt: %w", err)
+ }
+ return oldValue.PaidAt, nil
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (m *PaymentOrderMutation) ClearPaidAt() {
+ m.paid_at = nil
+ m.clearedFields[paymentorder.FieldPaidAt] = struct{}{}
+}
+
+// PaidAtCleared returns if the "paid_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) PaidAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldPaidAt]
+ return ok
+}
+
+// ResetPaidAt resets all changes to the "paid_at" field.
+func (m *PaymentOrderMutation) ResetPaidAt() {
+ m.paid_at = nil
+ delete(m.clearedFields, paymentorder.FieldPaidAt)
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) {
+ m.completed_at = &t
+}
+
+// CompletedAt returns the value of the "completed_at" field in the mutation.
+func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) {
+ v := m.completed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err)
+ }
+ return oldValue.CompletedAt, nil
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (m *PaymentOrderMutation) ClearCompletedAt() {
+ m.completed_at = nil
+ m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{}
+}
+
+// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) CompletedAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldCompletedAt]
+ return ok
+}
+
+// ResetCompletedAt resets all changes to the "completed_at" field.
+func (m *PaymentOrderMutation) ResetCompletedAt() {
+ m.completed_at = nil
+ delete(m.clearedFields, paymentorder.FieldCompletedAt)
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (m *PaymentOrderMutation) SetFailedAt(t time.Time) {
+ m.failed_at = &t
+}
+
+// FailedAt returns the value of the "failed_at" field in the mutation.
+func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) {
+ v := m.failed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFailedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFailedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFailedAt: %w", err)
+ }
+ return oldValue.FailedAt, nil
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (m *PaymentOrderMutation) ClearFailedAt() {
+ m.failed_at = nil
+ m.clearedFields[paymentorder.FieldFailedAt] = struct{}{}
+}
+
+// FailedAtCleared returns if the "failed_at" field was cleared in this mutation.
+// NOTE(review): appears to be entgo.io-generated mutation boilerplate inside a patch — regenerate via ent codegen rather than hand-editing.
+func (m *PaymentOrderMutation) FailedAtCleared() bool {
+	_, ok := m.clearedFields[paymentorder.FieldFailedAt]
+	return ok
+}
+
+// ResetFailedAt resets all changes to the "failed_at" field.
+func (m *PaymentOrderMutation) ResetFailedAt() {
+	m.failed_at = nil
+	delete(m.clearedFields, paymentorder.FieldFailedAt)
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (m *PaymentOrderMutation) SetFailedReason(s string) {
+	m.failed_reason = &s
+}
+
+// FailedReason returns the value of the "failed_reason" field in the mutation.
+func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) {
+	v := m.failed_reason
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldFailedReason is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldFailedReason requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldFailedReason: %w", err)
+	}
+	return oldValue.FailedReason, nil
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (m *PaymentOrderMutation) ClearFailedReason() {
+	m.failed_reason = nil
+	m.clearedFields[paymentorder.FieldFailedReason] = struct{}{}
+}
+
+// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation.
+func (m *PaymentOrderMutation) FailedReasonCleared() bool {
+	_, ok := m.clearedFields[paymentorder.FieldFailedReason]
+	return ok
+}
+
+// ResetFailedReason resets all changes to the "failed_reason" field.
+func (m *PaymentOrderMutation) ResetFailedReason() {
+	m.failed_reason = nil
+	delete(m.clearedFields, paymentorder.FieldFailedReason)
+}
+
+// NOTE(review): ent-generated accessors; these fields have no Clear/Cleared helpers, suggesting "client_ip" and "src_host" are non-nullable in the schema — confirm against the schema definition.
+// SetClientIP sets the "client_ip" field.
+func (m *PaymentOrderMutation) SetClientIP(s string) {
+	m.client_ip = &s
+}
+
+// ClientIP returns the value of the "client_ip" field in the mutation.
+func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) {
+	v := m.client_ip
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldClientIP is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldClientIP requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldClientIP: %w", err)
+	}
+	return oldValue.ClientIP, nil
+}
+
+// ResetClientIP resets all changes to the "client_ip" field.
+func (m *PaymentOrderMutation) ResetClientIP() {
+	m.client_ip = nil
+}
+
+// SetSrcHost sets the "src_host" field.
+func (m *PaymentOrderMutation) SetSrcHost(s string) {
+	m.src_host = &s
+}
+
+// SrcHost returns the value of the "src_host" field in the mutation.
+func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) {
+	v := m.src_host
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldSrcHost is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldSrcHost requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldSrcHost: %w", err)
+	}
+	return oldValue.SrcHost, nil
+}
+
+// ResetSrcHost resets all changes to the "src_host" field.
+func (m *PaymentOrderMutation) ResetSrcHost() {
+	m.src_host = nil
+}
+
+// NOTE(review): ent-generated accessors; "src_url" has Clear/Cleared helpers and a *string Old value, indicating a nullable schema field.
+// SetSrcURL sets the "src_url" field.
+func (m *PaymentOrderMutation) SetSrcURL(s string) {
+	m.src_url = &s
+}
+
+// SrcURL returns the value of the "src_url" field in the mutation.
+func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) {
+	v := m.src_url
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldSrcURL is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldSrcURL requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldSrcURL: %w", err)
+	}
+	return oldValue.SrcURL, nil
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (m *PaymentOrderMutation) ClearSrcURL() {
+	m.src_url = nil
+	m.clearedFields[paymentorder.FieldSrcURL] = struct{}{}
+}
+
+// SrcURLCleared returns if the "src_url" field was cleared in this mutation.
+func (m *PaymentOrderMutation) SrcURLCleared() bool {
+	_, ok := m.clearedFields[paymentorder.FieldSrcURL]
+	return ok
+}
+
+// ResetSrcURL resets all changes to the "src_url" field.
+func (m *PaymentOrderMutation) ResetSrcURL() {
+	m.src_url = nil
+	delete(m.clearedFields, paymentorder.FieldSrcURL)
+}
+
+// NOTE(review): ent-generated timestamp accessors (created_at / updated_at); regenerate rather than hand-edit.
+// SetCreatedAt sets the "created_at" field.
+func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) {
+	m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) {
+	v := m.created_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+	}
+	return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PaymentOrderMutation) ResetCreatedAt() {
+	m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) {
+	m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) {
+	v := m.updated_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+	}
+	return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PaymentOrderMutation) ResetUpdatedAt() {
+	m.updated_at = nil
+}
+
+// NOTE(review): ent-generated edge helpers and mutation metadata (Where/Op/Type); regenerate rather than hand-edit.
+// ClearUser clears the "user" edge to the User entity.
+func (m *PaymentOrderMutation) ClearUser() {
+	m.cleareduser = true
+	m.clearedFields[paymentorder.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *PaymentOrderMutation) UserCleared() bool {
+	return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *PaymentOrderMutation) UserIDs() (ids []int64) {
+	if id := m.user; id != nil {
+		ids = append(ids, *id)
+	}
+	return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *PaymentOrderMutation) ResetUser() {
+	m.user = nil
+	m.cleareduser = false
+}
+
+// Where appends a list predicates to the PaymentOrderMutation builder.
+func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) {
+	m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) {
+	p := make([]predicate.PaymentOrder, len(ps))
+	for i := range ps {
+		p[i] = ps[i]
+	}
+	m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PaymentOrderMutation) Op() Op {
+	return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PaymentOrderMutation) SetOp(op Op) {
+	m.op = op
+}
+
+// Type returns the node type of this mutation (PaymentOrder).
+func (m *PaymentOrderMutation) Type() string {
+	return m.typ
+}
+
+// NOTE(review): ent-generated; the capacity 37 matches the number of field checks below — keep them in sync via regeneration only.
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PaymentOrderMutation) Fields() []string {
+	fields := make([]string, 0, 37)
+	if m.user != nil {
+		fields = append(fields, paymentorder.FieldUserID)
+	}
+	if m.user_email != nil {
+		fields = append(fields, paymentorder.FieldUserEmail)
+	}
+	if m.user_name != nil {
+		fields = append(fields, paymentorder.FieldUserName)
+	}
+	if m.user_notes != nil {
+		fields = append(fields, paymentorder.FieldUserNotes)
+	}
+	if m.amount != nil {
+		fields = append(fields, paymentorder.FieldAmount)
+	}
+	if m.pay_amount != nil {
+		fields = append(fields, paymentorder.FieldPayAmount)
+	}
+	if m.fee_rate != nil {
+		fields = append(fields, paymentorder.FieldFeeRate)
+	}
+	if m.recharge_code != nil {
+		fields = append(fields, paymentorder.FieldRechargeCode)
+	}
+	if m.out_trade_no != nil {
+		fields = append(fields, paymentorder.FieldOutTradeNo)
+	}
+	if m.payment_type != nil {
+		fields = append(fields, paymentorder.FieldPaymentType)
+	}
+	if m.payment_trade_no != nil {
+		fields = append(fields, paymentorder.FieldPaymentTradeNo)
+	}
+	if m.pay_url != nil {
+		fields = append(fields, paymentorder.FieldPayURL)
+	}
+	if m.qr_code != nil {
+		fields = append(fields, paymentorder.FieldQrCode)
+	}
+	if m.qr_code_img != nil {
+		fields = append(fields, paymentorder.FieldQrCodeImg)
+	}
+	if m.order_type != nil {
+		fields = append(fields, paymentorder.FieldOrderType)
+	}
+	if m.plan_id != nil {
+		fields = append(fields, paymentorder.FieldPlanID)
+	}
+	if m.subscription_group_id != nil {
+		fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+	}
+	if m.subscription_days != nil {
+		fields = append(fields, paymentorder.FieldSubscriptionDays)
+	}
+	if m.provider_instance_id != nil {
+		fields = append(fields, paymentorder.FieldProviderInstanceID)
+	}
+	if m.status != nil {
+		fields = append(fields, paymentorder.FieldStatus)
+	}
+	if m.refund_amount != nil {
+		fields = append(fields, paymentorder.FieldRefundAmount)
+	}
+	if m.refund_reason != nil {
+		fields = append(fields, paymentorder.FieldRefundReason)
+	}
+	if m.refund_at != nil {
+		fields = append(fields, paymentorder.FieldRefundAt)
+	}
+	if m.force_refund != nil {
+		fields = append(fields, paymentorder.FieldForceRefund)
+	}
+	if m.refund_requested_at != nil {
+		fields = append(fields, paymentorder.FieldRefundRequestedAt)
+	}
+	if m.refund_request_reason != nil {
+		fields = append(fields, paymentorder.FieldRefundRequestReason)
+	}
+	if m.refund_requested_by != nil {
+		fields = append(fields, paymentorder.FieldRefundRequestedBy)
+	}
+	if m.expires_at != nil {
+		fields = append(fields, paymentorder.FieldExpiresAt)
+	}
+	if m.paid_at != nil {
+		fields = append(fields, paymentorder.FieldPaidAt)
+	}
+	if m.completed_at != nil {
+		fields = append(fields, paymentorder.FieldCompletedAt)
+	}
+	if m.failed_at != nil {
+		fields = append(fields, paymentorder.FieldFailedAt)
+	}
+	if m.failed_reason != nil {
+		fields = append(fields, paymentorder.FieldFailedReason)
+	}
+	if m.client_ip != nil {
+		fields = append(fields, paymentorder.FieldClientIP)
+	}
+	if m.src_host != nil {
+		fields = append(fields, paymentorder.FieldSrcHost)
+	}
+	if m.src_url != nil {
+		fields = append(fields, paymentorder.FieldSrcURL)
+	}
+	if m.created_at != nil {
+		fields = append(fields, paymentorder.FieldCreatedAt)
+	}
+	if m.updated_at != nil {
+		fields = append(fields, paymentorder.FieldUpdatedAt)
+	}
+	return fields
+}
+
+// NOTE(review): ent-generated dispatch over field names; regenerate rather than hand-edit.
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
+	switch name {
+	case paymentorder.FieldUserID:
+		return m.UserID()
+	case paymentorder.FieldUserEmail:
+		return m.UserEmail()
+	case paymentorder.FieldUserName:
+		return m.UserName()
+	case paymentorder.FieldUserNotes:
+		return m.UserNotes()
+	case paymentorder.FieldAmount:
+		return m.Amount()
+	case paymentorder.FieldPayAmount:
+		return m.PayAmount()
+	case paymentorder.FieldFeeRate:
+		return m.FeeRate()
+	case paymentorder.FieldRechargeCode:
+		return m.RechargeCode()
+	case paymentorder.FieldOutTradeNo:
+		return m.OutTradeNo()
+	case paymentorder.FieldPaymentType:
+		return m.PaymentType()
+	case paymentorder.FieldPaymentTradeNo:
+		return m.PaymentTradeNo()
+	case paymentorder.FieldPayURL:
+		return m.PayURL()
+	case paymentorder.FieldQrCode:
+		return m.QrCode()
+	case paymentorder.FieldQrCodeImg:
+		return m.QrCodeImg()
+	case paymentorder.FieldOrderType:
+		return m.OrderType()
+	case paymentorder.FieldPlanID:
+		return m.PlanID()
+	case paymentorder.FieldSubscriptionGroupID:
+		return m.SubscriptionGroupID()
+	case paymentorder.FieldSubscriptionDays:
+		return m.SubscriptionDays()
+	case paymentorder.FieldProviderInstanceID:
+		return m.ProviderInstanceID()
+	case paymentorder.FieldStatus:
+		return m.Status()
+	case paymentorder.FieldRefundAmount:
+		return m.RefundAmount()
+	case paymentorder.FieldRefundReason:
+		return m.RefundReason()
+	case paymentorder.FieldRefundAt:
+		return m.RefundAt()
+	case paymentorder.FieldForceRefund:
+		return m.ForceRefund()
+	case paymentorder.FieldRefundRequestedAt:
+		return m.RefundRequestedAt()
+	case paymentorder.FieldRefundRequestReason:
+		return m.RefundRequestReason()
+	case paymentorder.FieldRefundRequestedBy:
+		return m.RefundRequestedBy()
+	case paymentorder.FieldExpiresAt:
+		return m.ExpiresAt()
+	case paymentorder.FieldPaidAt:
+		return m.PaidAt()
+	case paymentorder.FieldCompletedAt:
+		return m.CompletedAt()
+	case paymentorder.FieldFailedAt:
+		return m.FailedAt()
+	case paymentorder.FieldFailedReason:
+		return m.FailedReason()
+	case paymentorder.FieldClientIP:
+		return m.ClientIP()
+	case paymentorder.FieldSrcHost:
+		return m.SrcHost()
+	case paymentorder.FieldSrcURL:
+		return m.SrcURL()
+	case paymentorder.FieldCreatedAt:
+		return m.CreatedAt()
+	case paymentorder.FieldUpdatedAt:
+		return m.UpdatedAt()
+	}
+	return nil, false
+}
+
+// NOTE(review): ent-generated dispatch to the Old* accessors; regenerate rather than hand-edit.
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+	switch name {
+	case paymentorder.FieldUserID:
+		return m.OldUserID(ctx)
+	case paymentorder.FieldUserEmail:
+		return m.OldUserEmail(ctx)
+	case paymentorder.FieldUserName:
+		return m.OldUserName(ctx)
+	case paymentorder.FieldUserNotes:
+		return m.OldUserNotes(ctx)
+	case paymentorder.FieldAmount:
+		return m.OldAmount(ctx)
+	case paymentorder.FieldPayAmount:
+		return m.OldPayAmount(ctx)
+	case paymentorder.FieldFeeRate:
+		return m.OldFeeRate(ctx)
+	case paymentorder.FieldRechargeCode:
+		return m.OldRechargeCode(ctx)
+	case paymentorder.FieldOutTradeNo:
+		return m.OldOutTradeNo(ctx)
+	case paymentorder.FieldPaymentType:
+		return m.OldPaymentType(ctx)
+	case paymentorder.FieldPaymentTradeNo:
+		return m.OldPaymentTradeNo(ctx)
+	case paymentorder.FieldPayURL:
+		return m.OldPayURL(ctx)
+	case paymentorder.FieldQrCode:
+		return m.OldQrCode(ctx)
+	case paymentorder.FieldQrCodeImg:
+		return m.OldQrCodeImg(ctx)
+	case paymentorder.FieldOrderType:
+		return m.OldOrderType(ctx)
+	case paymentorder.FieldPlanID:
+		return m.OldPlanID(ctx)
+	case paymentorder.FieldSubscriptionGroupID:
+		return m.OldSubscriptionGroupID(ctx)
+	case paymentorder.FieldSubscriptionDays:
+		return m.OldSubscriptionDays(ctx)
+	case paymentorder.FieldProviderInstanceID:
+		return m.OldProviderInstanceID(ctx)
+	case paymentorder.FieldStatus:
+		return m.OldStatus(ctx)
+	case paymentorder.FieldRefundAmount:
+		return m.OldRefundAmount(ctx)
+	case paymentorder.FieldRefundReason:
+		return m.OldRefundReason(ctx)
+	case paymentorder.FieldRefundAt:
+		return m.OldRefundAt(ctx)
+	case paymentorder.FieldForceRefund:
+		return m.OldForceRefund(ctx)
+	case paymentorder.FieldRefundRequestedAt:
+		return m.OldRefundRequestedAt(ctx)
+	case paymentorder.FieldRefundRequestReason:
+		return m.OldRefundRequestReason(ctx)
+	case paymentorder.FieldRefundRequestedBy:
+		return m.OldRefundRequestedBy(ctx)
+	case paymentorder.FieldExpiresAt:
+		return m.OldExpiresAt(ctx)
+	case paymentorder.FieldPaidAt:
+		return m.OldPaidAt(ctx)
+	case paymentorder.FieldCompletedAt:
+		return m.OldCompletedAt(ctx)
+	case paymentorder.FieldFailedAt:
+		return m.OldFailedAt(ctx)
+	case paymentorder.FieldFailedReason:
+		return m.OldFailedReason(ctx)
+	case paymentorder.FieldClientIP:
+		return m.OldClientIP(ctx)
+	case paymentorder.FieldSrcHost:
+		return m.OldSrcHost(ctx)
+	case paymentorder.FieldSrcURL:
+		return m.OldSrcURL(ctx)
+	case paymentorder.FieldCreatedAt:
+		return m.OldCreatedAt(ctx)
+	case paymentorder.FieldUpdatedAt:
+		return m.OldUpdatedAt(ctx)
+	}
+	return nil, fmt.Errorf("unknown PaymentOrder field %s", name)
+}
+
+// NOTE(review): ent-generated type-checked setter dispatch. The lines below include a diff-hunk splice (context and removed lines around FieldRefundAt) that is part of the patch payload and is preserved byte-for-byte.
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
+	switch name {
+	case paymentorder.FieldUserID:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetUserID(v)
+		return nil
+	case paymentorder.FieldUserEmail:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetUserEmail(v)
+		return nil
+	case paymentorder.FieldUserName:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetUserName(v)
+		return nil
+	case paymentorder.FieldUserNotes:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetUserNotes(v)
+		return nil
+	case paymentorder.FieldAmount:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetAmount(v)
+		return nil
+	case paymentorder.FieldPayAmount:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetPayAmount(v)
+		return nil
+	case paymentorder.FieldFeeRate:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetFeeRate(v)
+		return nil
+	case paymentorder.FieldRechargeCode:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetRechargeCode(v)
+		return nil
+	case paymentorder.FieldOutTradeNo:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetOutTradeNo(v)
+		return nil
+	case paymentorder.FieldPaymentType:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetPaymentType(v)
+		return nil
+	case paymentorder.FieldPaymentTradeNo:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetPaymentTradeNo(v)
+		return nil
+	case paymentorder.FieldPayURL:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetPayURL(v)
+		return nil
+	case paymentorder.FieldQrCode:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetQrCode(v)
+		return nil
+	case paymentorder.FieldQrCodeImg:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetQrCodeImg(v)
+		return nil
+	case paymentorder.FieldOrderType:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetOrderType(v)
+		return nil
+	case paymentorder.FieldPlanID:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetPlanID(v)
+		return nil
+	case paymentorder.FieldSubscriptionGroupID:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetSubscriptionGroupID(v)
+		return nil
+	case paymentorder.FieldSubscriptionDays:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetSubscriptionDays(v)
+		return nil
+	case paymentorder.FieldProviderInstanceID:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetProviderInstanceID(v)
+		return nil
+	case paymentorder.FieldStatus:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetStatus(v)
+		return nil
+	case paymentorder.FieldRefundAmount:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetRefundAmount(v)
+		return nil
+	case paymentorder.FieldRefundReason:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetRefundReason(v)
+		return nil
+	case paymentorder.FieldRefundAt:
+		v, ok := value.(time.Time)
 		if !ok {
 			return fmt.Errorf("unexpected type %T for field %s", value, name)
 		}
-		m.SetRefundAmount(v)
+		m.SetRefundAt(v)
+		return nil
+	case paymentorder.FieldForceRefund:
+		v, ok := value.(bool)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetForceRefund(v)
+		return nil
+	case paymentorder.FieldRefundRequestedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetRefundRequestedAt(v)
+		return nil
+	case paymentorder.FieldRefundRequestReason:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetRefundRequestReason(v)
+		return nil
+	case paymentorder.FieldRefundRequestedBy:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetRefundRequestedBy(v)
+		return nil
+	case paymentorder.FieldExpiresAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetExpiresAt(v)
+		return nil
+	case paymentorder.FieldPaidAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetPaidAt(v)
+		return nil
+	case paymentorder.FieldCompletedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCompletedAt(v)
+		return nil
+	case paymentorder.FieldFailedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetFailedAt(v)
+		return nil
+	case paymentorder.FieldFailedReason:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetFailedReason(v)
+		return nil
+	case paymentorder.FieldClientIP:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetClientIP(v)
+		return nil
+	case paymentorder.FieldSrcHost:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetSrcHost(v)
+		return nil
+	case paymentorder.FieldSrcURL:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetSrcURL(v)
+		return nil
+	case paymentorder.FieldCreatedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCreatedAt(v)
+		return nil
+	case paymentorder.FieldUpdatedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetUpdatedAt(v)
+		return nil
+	}
+	return fmt.Errorf("unknown PaymentOrder field %s", name)
+}
+
+// NOTE(review): ent-generated numeric-increment bookkeeping; regenerate rather than hand-edit.
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PaymentOrderMutation) AddedFields() []string {
+	var fields []string
+	if m.addamount != nil {
+		fields = append(fields, paymentorder.FieldAmount)
+	}
+	if m.addpay_amount != nil {
+		fields = append(fields, paymentorder.FieldPayAmount)
+	}
+	if m.addfee_rate != nil {
+		fields = append(fields, paymentorder.FieldFeeRate)
+	}
+	if m.addplan_id != nil {
+		fields = append(fields, paymentorder.FieldPlanID)
+	}
+	if m.addsubscription_group_id != nil {
+		fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+	}
+	if m.addsubscription_days != nil {
+		fields = append(fields, paymentorder.FieldSubscriptionDays)
+	}
+	if m.addrefund_amount != nil {
+		fields = append(fields, paymentorder.FieldRefundAmount)
+	}
+	return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) {
+	switch name {
+	case paymentorder.FieldAmount:
+		return m.AddedAmount()
+	case paymentorder.FieldPayAmount:
+		return m.AddedPayAmount()
+	case paymentorder.FieldFeeRate:
+		return m.AddedFeeRate()
+	case paymentorder.FieldPlanID:
+		return m.AddedPlanID()
+	case paymentorder.FieldSubscriptionGroupID:
+		return m.AddedSubscriptionGroupID()
+	case paymentorder.FieldSubscriptionDays:
+		return m.AddedSubscriptionDays()
+	case paymentorder.FieldRefundAmount:
+		return m.AddedRefundAmount()
+	}
+	return nil, false
+}
+
+// NOTE(review): ent-generated type-checked add dispatch for the seven numeric fields; regenerate rather than hand-edit.
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error {
+	switch name {
+	case paymentorder.FieldAmount:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddAmount(v)
+		return nil
+	case paymentorder.FieldPayAmount:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddPayAmount(v)
+		return nil
+	case paymentorder.FieldFeeRate:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddFeeRate(v)
+		return nil
+	case paymentorder.FieldPlanID:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddPlanID(v)
+		return nil
+	case paymentorder.FieldSubscriptionGroupID:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddSubscriptionGroupID(v)
+		return nil
+	case paymentorder.FieldSubscriptionDays:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddSubscriptionDays(v)
+		return nil
+	case paymentorder.FieldRefundAmount:
+		v, ok := value.(float64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddRefundAmount(v)
+		return nil
+	}
+	return fmt.Errorf("unknown PaymentOrder numeric field %s", name)
+}
+
+// NOTE(review): ent-generated cleared-field reporting; regenerate rather than hand-edit.
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PaymentOrderMutation) ClearedFields() []string {
+	var fields []string
+	if m.FieldCleared(paymentorder.FieldUserNotes) {
+		fields = append(fields, paymentorder.FieldUserNotes)
+	}
+	if m.FieldCleared(paymentorder.FieldPayURL) {
+		fields = append(fields, paymentorder.FieldPayURL)
+	}
+	if m.FieldCleared(paymentorder.FieldQrCode) {
+		fields = append(fields, paymentorder.FieldQrCode)
+	}
+	if m.FieldCleared(paymentorder.FieldQrCodeImg) {
+		fields = append(fields, paymentorder.FieldQrCodeImg)
+	}
+	if m.FieldCleared(paymentorder.FieldPlanID) {
+		fields = append(fields, paymentorder.FieldPlanID)
+	}
+	if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) {
+		fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+	}
+	if m.FieldCleared(paymentorder.FieldSubscriptionDays) {
+		fields = append(fields, paymentorder.FieldSubscriptionDays)
+	}
+	if m.FieldCleared(paymentorder.FieldProviderInstanceID) {
+		fields = append(fields, paymentorder.FieldProviderInstanceID)
+	}
+	if m.FieldCleared(paymentorder.FieldRefundReason) {
+		fields = append(fields, paymentorder.FieldRefundReason)
+	}
+	if m.FieldCleared(paymentorder.FieldRefundAt) {
+		fields = append(fields, paymentorder.FieldRefundAt)
+	}
+	if m.FieldCleared(paymentorder.FieldRefundRequestedAt) {
+		fields = append(fields, paymentorder.FieldRefundRequestedAt)
+	}
+	if m.FieldCleared(paymentorder.FieldRefundRequestReason) {
+		fields = append(fields, paymentorder.FieldRefundRequestReason)
+	}
+	if m.FieldCleared(paymentorder.FieldRefundRequestedBy) {
+		fields = append(fields, paymentorder.FieldRefundRequestedBy)
+	}
+	if m.FieldCleared(paymentorder.FieldPaidAt) {
+		fields = append(fields, paymentorder.FieldPaidAt)
+	}
+	if m.FieldCleared(paymentorder.FieldCompletedAt) {
+		fields = append(fields, paymentorder.FieldCompletedAt)
+	}
+	if m.FieldCleared(paymentorder.FieldFailedAt) {
+		fields = append(fields, paymentorder.FieldFailedAt)
+	}
+	if m.FieldCleared(paymentorder.FieldFailedReason) {
+		fields = append(fields, paymentorder.FieldFailedReason)
+	}
+	if m.FieldCleared(paymentorder.FieldSrcURL) {
+		fields = append(fields, paymentorder.FieldSrcURL)
+	}
+	return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PaymentOrderMutation) FieldCleared(name string) bool {
+	_, ok := m.clearedFields[name]
+	return ok
+}
+
+// NOTE(review): ent-generated clear dispatch. The lines below include diff-hunk splices (context and removed lines around FieldRefundReason/FieldRefundAt) that are part of the patch payload and are preserved byte-for-byte.
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PaymentOrderMutation) ClearField(name string) error {
+	switch name {
+	case paymentorder.FieldUserNotes:
+		m.ClearUserNotes()
+		return nil
+	case paymentorder.FieldPayURL:
+		m.ClearPayURL()
+		return nil
+	case paymentorder.FieldQrCode:
+		m.ClearQrCode()
+		return nil
+	case paymentorder.FieldQrCodeImg:
+		m.ClearQrCodeImg()
+		return nil
+	case paymentorder.FieldPlanID:
+		m.ClearPlanID()
+		return nil
+	case paymentorder.FieldSubscriptionGroupID:
+		m.ClearSubscriptionGroupID()
+		return nil
+	case paymentorder.FieldSubscriptionDays:
+		m.ClearSubscriptionDays()
+		return nil
+	case paymentorder.FieldProviderInstanceID:
+		m.ClearProviderInstanceID()
 		return nil
 	case paymentorder.FieldRefundReason:
-		v, ok := value.(string)
-		if !ok {
-			return fmt.Errorf("unexpected type %T for field %s", value, name)
-		}
-		m.SetRefundReason(v)
+		m.ClearRefundReason()
 		return nil
 	case paymentorder.FieldRefundAt:
-		v, ok := value.(time.Time)
-		if !ok {
-			return fmt.Errorf("unexpected type %T for field %s", value, name)
-		}
-		m.SetRefundAt(v)
+		m.ClearRefundAt()
+		return nil
+	case paymentorder.FieldRefundRequestedAt:
+		m.ClearRefundRequestedAt()
+		return nil
+	case paymentorder.FieldRefundRequestReason:
+		m.ClearRefundRequestReason()
+		return nil
+	case paymentorder.FieldRefundRequestedBy:
+		m.ClearRefundRequestedBy()
+		return nil
+	case paymentorder.FieldPaidAt:
+		m.ClearPaidAt()
+		return nil
+	case paymentorder.FieldCompletedAt:
+		m.ClearCompletedAt()
+		return nil
+	case paymentorder.FieldFailedAt:
+		m.ClearFailedAt()
+		return nil
+	case paymentorder.FieldFailedReason:
+		m.ClearFailedReason()
+		return nil
+	case paymentorder.FieldSrcURL:
+		m.ClearSrcURL()
+		return nil
+	}
+	return fmt.Errorf("unknown PaymentOrder nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PaymentOrderMutation) ResetField(name string) error {
+ switch name {
+ case paymentorder.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case paymentorder.FieldUserEmail:
+ m.ResetUserEmail()
+ return nil
+ case paymentorder.FieldUserName:
+ m.ResetUserName()
+ return nil
+ case paymentorder.FieldUserNotes:
+ m.ResetUserNotes()
+ return nil
+ case paymentorder.FieldAmount:
+ m.ResetAmount()
+ return nil
+ case paymentorder.FieldPayAmount:
+ m.ResetPayAmount()
+ return nil
+ case paymentorder.FieldFeeRate:
+ m.ResetFeeRate()
+ return nil
+ case paymentorder.FieldRechargeCode:
+ m.ResetRechargeCode()
+ return nil
+ case paymentorder.FieldOutTradeNo:
+ m.ResetOutTradeNo()
+ return nil
+ case paymentorder.FieldPaymentType:
+ m.ResetPaymentType()
+ return nil
+ case paymentorder.FieldPaymentTradeNo:
+ m.ResetPaymentTradeNo()
+ return nil
+ case paymentorder.FieldPayURL:
+ m.ResetPayURL()
+ return nil
+ case paymentorder.FieldQrCode:
+ m.ResetQrCode()
+ return nil
+ case paymentorder.FieldQrCodeImg:
+ m.ResetQrCodeImg()
+ return nil
+ case paymentorder.FieldOrderType:
+ m.ResetOrderType()
+ return nil
+ case paymentorder.FieldPlanID:
+ m.ResetPlanID()
+ return nil
+ case paymentorder.FieldSubscriptionGroupID:
+ m.ResetSubscriptionGroupID()
+ return nil
+ case paymentorder.FieldSubscriptionDays:
+ m.ResetSubscriptionDays()
+ return nil
+ case paymentorder.FieldProviderInstanceID:
+ m.ResetProviderInstanceID()
+ return nil
+ case paymentorder.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case paymentorder.FieldRefundAmount:
+ m.ResetRefundAmount()
+ return nil
+ case paymentorder.FieldRefundReason:
+ m.ResetRefundReason()
+ return nil
+ case paymentorder.FieldRefundAt:
+ m.ResetRefundAt()
return nil
case paymentorder.FieldForceRefund:
- v, ok := value.(bool)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetForceRefund(v)
+ m.ResetForceRefund()
return nil
case paymentorder.FieldRefundRequestedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRefundRequestedAt(v)
+ m.ResetRefundRequestedAt()
return nil
case paymentorder.FieldRefundRequestReason:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRefundRequestReason(v)
+ m.ResetRefundRequestReason()
return nil
case paymentorder.FieldRefundRequestedBy:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetRefundRequestedBy(v)
+ m.ResetRefundRequestedBy()
return nil
case paymentorder.FieldExpiresAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetExpiresAt(v)
+ m.ResetExpiresAt()
return nil
case paymentorder.FieldPaidAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetPaidAt(v)
+ m.ResetPaidAt()
return nil
case paymentorder.FieldCompletedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetCompletedAt(v)
+ m.ResetCompletedAt()
return nil
case paymentorder.FieldFailedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetFailedAt(v)
+ m.ResetFailedAt()
return nil
case paymentorder.FieldFailedReason:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetFailedReason(v)
+ m.ResetFailedReason()
return nil
case paymentorder.FieldClientIP:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
+ m.ResetClientIP()
+ return nil
+ case paymentorder.FieldSrcHost:
+ m.ResetSrcHost()
+ return nil
+ case paymentorder.FieldSrcURL:
+ m.ResetSrcURL()
+ return nil
+ case paymentorder.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case paymentorder.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PaymentOrderMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.user != nil {
+ edges = append(edges, paymentorder.EdgeUser)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case paymentorder.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
}
- m.SetClientIP(v)
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PaymentOrderMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PaymentOrderMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.cleareduser {
+ edges = append(edges, paymentorder.EdgeUser)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PaymentOrderMutation) EdgeCleared(name string) bool {
+ switch name {
+ case paymentorder.EdgeUser:
+ return m.cleareduser
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PaymentOrderMutation) ClearEdge(name string) error {
+ switch name {
+ case paymentorder.EdgeUser:
+ m.ClearUser()
return nil
- case paymentorder.FieldSrcHost:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSrcHost(v)
+ }
+ return fmt.Errorf("unknown PaymentOrder unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PaymentOrderMutation) ResetEdge(name string) error {
+ switch name {
+ case paymentorder.EdgeUser:
+ m.ResetUser()
return nil
- case paymentorder.FieldSrcURL:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ return fmt.Errorf("unknown PaymentOrder edge %s", name)
+}
+
+// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
+type PaymentProviderInstanceMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ provider_key *string
+ name *string
+ _config *string
+ supported_types *string
+ enabled *bool
+ payment_mode *string
+ sort_order *int
+ addsort_order *int
+ limits *string
+ refund_enabled *bool
+ allow_user_refund *bool
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*PaymentProviderInstance, error)
+ predicates []predicate.PaymentProviderInstance
+}
+
+var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
+
+// paymentproviderinstanceOption allows management of the mutation configuration using functional options.
+type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation)
+
+// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity.
+func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation {
+ m := &PaymentProviderInstanceMutation{
+ config: c,
+ op: op,
+ typ: TypePaymentProviderInstance,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPaymentProviderInstanceID sets the ID field of the mutation.
+func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption {
+ return func(m *PaymentProviderInstanceMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PaymentProviderInstance
+ )
+ m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PaymentProviderInstance.Get(ctx, id)
+ }
+ })
+ return value, err
}
- m.SetSrcURL(v)
- return nil
- case paymentorder.FieldCreatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
+ m.id = &id
+ }
+}
+
+// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation.
+func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption {
+ return func(m *PaymentProviderInstanceMutation) {
+ m.oldValue = func(context.Context) (*PaymentProviderInstance, error) {
+ return node, nil
}
- m.SetCreatedAt(v)
- return nil
- case paymentorder.FieldUpdatedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PaymentProviderInstanceMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
}
- m.SetUpdatedAt(v)
- return nil
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PaymentProviderInstanceMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetName sets the "name" field.
+func (m *PaymentProviderInstanceMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *PaymentProviderInstanceMutation) ResetName() {
+ m.name = nil
+}
+
+// SetConfig sets the "config" field.
+func (m *PaymentProviderInstanceMutation) SetConfig(s string) {
+ m._config = &s
+}
+
+// Config returns the value of the "config" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) {
+ v := m._config
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldConfig is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldConfig requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldConfig: %w", err)
+ }
+ return oldValue.Config, nil
+}
+
+// ResetConfig resets all changes to the "config" field.
+func (m *PaymentProviderInstanceMutation) ResetConfig() {
+ m._config = nil
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) {
+ m.supported_types = &s
+}
+
+// SupportedTypes returns the value of the "supported_types" field in the mutation.
+func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) {
+ v := m.supported_types
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSupportedTypes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err)
+ }
+ return oldValue.SupportedTypes, nil
+}
+
+// ResetSupportedTypes resets all changes to the "supported_types" field.
+func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() {
+ m.supported_types = nil
+}
+
+// SetEnabled sets the "enabled" field.
+func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) {
+ m.enabled = &b
+}
+
+// Enabled returns the value of the "enabled" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) {
+ v := m.enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
+ }
+ return oldValue.Enabled, nil
+}
+
+// ResetEnabled resets all changes to the "enabled" field.
+func (m *PaymentProviderInstanceMutation) ResetEnabled() {
+ m.enabled = nil
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) {
+ m.payment_mode = &s
+}
+
+// PaymentMode returns the value of the "payment_mode" field in the mutation.
+func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) {
+ v := m.payment_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaymentMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err)
+ }
+ return oldValue.PaymentMode, nil
+}
+
+// ResetPaymentMode resets all changes to the "payment_mode" field.
+func (m *PaymentProviderInstanceMutation) ResetPaymentMode() {
+ m.payment_mode = nil
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) {
+ m.sort_order = &i
+ m.addsort_order = nil
+}
+
+// SortOrder returns the value of the "sort_order" field in the mutation.
+func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) {
+ v := m.sort_order
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSortOrder requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
+ }
+ return oldValue.SortOrder, nil
+}
+
+// AddSortOrder adds i to the "sort_order" field.
+func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) {
+ if m.addsort_order != nil {
+ *m.addsort_order += i
+ } else {
+ m.addsort_order = &i
+ }
+}
+
+// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
+func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) {
+ v := m.addsort_order
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSortOrder resets all changes to the "sort_order" field.
+func (m *PaymentProviderInstanceMutation) ResetSortOrder() {
+ m.sort_order = nil
+ m.addsort_order = nil
+}
+
+// SetLimits sets the "limits" field.
+func (m *PaymentProviderInstanceMutation) SetLimits(s string) {
+ m.limits = &s
+}
+
+// Limits returns the value of the "limits" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) {
+ v := m.limits
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLimits is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLimits requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLimits: %w", err)
+ }
+ return oldValue.Limits, nil
+}
+
+// ResetLimits resets all changes to the "limits" field.
+func (m *PaymentProviderInstanceMutation) ResetLimits() {
+ m.limits = nil
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) {
+ m.refund_enabled = &b
+}
+
+// RefundEnabled returns the value of the "refund_enabled" field in the mutation.
+func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) {
+ v := m.refund_enabled
+ if v == nil {
+ return
}
- return fmt.Errorf("unknown PaymentOrder field %s", name)
+ return *v, true
}
-// AddedFields returns all numeric fields that were incremented/decremented during
-// this mutation.
-func (m *PaymentOrderMutation) AddedFields() []string {
- var fields []string
- if m.addamount != nil {
- fields = append(fields, paymentorder.FieldAmount)
+// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations")
}
- if m.addpay_amount != nil {
- fields = append(fields, paymentorder.FieldPayAmount)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundEnabled requires an ID field in the mutation")
}
- if m.addfee_rate != nil {
- fields = append(fields, paymentorder.FieldFeeRate)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err)
}
- if m.addplan_id != nil {
- fields = append(fields, paymentorder.FieldPlanID)
+ return oldValue.RefundEnabled, nil
+}
+
+// ResetRefundEnabled resets all changes to the "refund_enabled" field.
+func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
+ m.refund_enabled = nil
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
+ m.allow_user_refund = &b
+}
+
+// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
+func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
+ v := m.allow_user_refund
+ if v == nil {
+ return
}
- if m.addsubscription_group_id != nil {
- fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+ return *v, true
+}
+
+// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
}
- if m.addsubscription_days != nil {
- fields = append(fields, paymentorder.FieldSubscriptionDays)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
}
- if m.addrefund_amount != nil {
- fields = append(fields, paymentorder.FieldRefundAmount)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
}
- return fields
+ return oldValue.AllowUserRefund, nil
}
-// AddedField returns the numeric value that was incremented/decremented on a field
-// with the given name. The second boolean return value indicates that this field
-// was not set, or was not defined in the schema.
-func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) {
- switch name {
- case paymentorder.FieldAmount:
- return m.AddedAmount()
- case paymentorder.FieldPayAmount:
- return m.AddedPayAmount()
- case paymentorder.FieldFeeRate:
- return m.AddedFeeRate()
- case paymentorder.FieldPlanID:
- return m.AddedPlanID()
- case paymentorder.FieldSubscriptionGroupID:
- return m.AddedSubscriptionGroupID()
- case paymentorder.FieldSubscriptionDays:
- return m.AddedSubscriptionDays()
- case paymentorder.FieldRefundAmount:
- return m.AddedRefundAmount()
- }
- return nil, false
+// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
+func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
+ m.allow_user_refund = nil
}
-// AddField adds the value to the field with the given name. It returns an error if
-// the field is not defined in the schema, or if the type mismatched the field
-// type.
-func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error {
- switch name {
- case paymentorder.FieldAmount:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddAmount(v)
- return nil
- case paymentorder.FieldPayAmount:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddPayAmount(v)
- return nil
- case paymentorder.FieldFeeRate:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddFeeRate(v)
- return nil
- case paymentorder.FieldPlanID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddPlanID(v)
- return nil
- case paymentorder.FieldSubscriptionGroupID:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSubscriptionGroupID(v)
- return nil
- case paymentorder.FieldSubscriptionDays:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSubscriptionDays(v)
- return nil
- case paymentorder.FieldRefundAmount:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddRefundAmount(v)
- return nil
+// SetCreatedAt sets the "created_at" field.
+func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
}
- return fmt.Errorf("unknown PaymentOrder numeric field %s", name)
+ return *v, true
}
-// ClearedFields returns all nullable fields that were cleared during this
-// mutation.
-func (m *PaymentOrderMutation) ClearedFields() []string {
- var fields []string
- if m.FieldCleared(paymentorder.FieldUserNotes) {
- fields = append(fields, paymentorder.FieldUserNotes)
+// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
- if m.FieldCleared(paymentorder.FieldPayURL) {
- fields = append(fields, paymentorder.FieldPayURL)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
- if m.FieldCleared(paymentorder.FieldQrCode) {
- fields = append(fields, paymentorder.FieldQrCode)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
- if m.FieldCleared(paymentorder.FieldQrCodeImg) {
- fields = append(fields, paymentorder.FieldQrCodeImg)
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PaymentProviderInstanceMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
}
- if m.FieldCleared(paymentorder.FieldPlanID) {
- fields = append(fields, paymentorder.FieldPlanID)
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
}
- if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) {
- fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
}
- if m.FieldCleared(paymentorder.FieldSubscriptionDays) {
- fields = append(fields, paymentorder.FieldSubscriptionDays)
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// Where appends a list predicates to the PaymentProviderInstanceMutation builder.
+func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PaymentProviderInstance, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PaymentProviderInstanceMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PaymentProviderInstanceMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PaymentProviderInstance).
+func (m *PaymentProviderInstanceMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PaymentProviderInstanceMutation) Fields() []string {
+ fields := make([]string, 0, 12)
+ if m.provider_key != nil {
+ fields = append(fields, paymentproviderinstance.FieldProviderKey)
}
- if m.FieldCleared(paymentorder.FieldProviderInstanceID) {
- fields = append(fields, paymentorder.FieldProviderInstanceID)
+ if m.name != nil {
+ fields = append(fields, paymentproviderinstance.FieldName)
}
- if m.FieldCleared(paymentorder.FieldRefundReason) {
- fields = append(fields, paymentorder.FieldRefundReason)
+ if m._config != nil {
+ fields = append(fields, paymentproviderinstance.FieldConfig)
}
- if m.FieldCleared(paymentorder.FieldRefundAt) {
- fields = append(fields, paymentorder.FieldRefundAt)
+ if m.supported_types != nil {
+ fields = append(fields, paymentproviderinstance.FieldSupportedTypes)
}
- if m.FieldCleared(paymentorder.FieldRefundRequestedAt) {
- fields = append(fields, paymentorder.FieldRefundRequestedAt)
+ if m.enabled != nil {
+ fields = append(fields, paymentproviderinstance.FieldEnabled)
}
- if m.FieldCleared(paymentorder.FieldRefundRequestReason) {
- fields = append(fields, paymentorder.FieldRefundRequestReason)
+ if m.payment_mode != nil {
+ fields = append(fields, paymentproviderinstance.FieldPaymentMode)
}
- if m.FieldCleared(paymentorder.FieldRefundRequestedBy) {
- fields = append(fields, paymentorder.FieldRefundRequestedBy)
+ if m.sort_order != nil {
+ fields = append(fields, paymentproviderinstance.FieldSortOrder)
}
- if m.FieldCleared(paymentorder.FieldPaidAt) {
- fields = append(fields, paymentorder.FieldPaidAt)
+ if m.limits != nil {
+ fields = append(fields, paymentproviderinstance.FieldLimits)
}
- if m.FieldCleared(paymentorder.FieldCompletedAt) {
- fields = append(fields, paymentorder.FieldCompletedAt)
+ if m.refund_enabled != nil {
+ fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
}
- if m.FieldCleared(paymentorder.FieldFailedAt) {
- fields = append(fields, paymentorder.FieldFailedAt)
+ if m.allow_user_refund != nil {
+ fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
}
- if m.FieldCleared(paymentorder.FieldFailedReason) {
- fields = append(fields, paymentorder.FieldFailedReason)
+ if m.created_at != nil {
+ fields = append(fields, paymentproviderinstance.FieldCreatedAt)
}
- if m.FieldCleared(paymentorder.FieldSrcURL) {
- fields = append(fields, paymentorder.FieldSrcURL)
+ if m.updated_at != nil {
+ fields = append(fields, paymentproviderinstance.FieldUpdatedAt)
}
return fields
}
-// FieldCleared returns a boolean indicating if a field with the given name was
-// cleared in this mutation.
-func (m *PaymentOrderMutation) FieldCleared(name string) bool {
- _, ok := m.clearedFields[name]
- return ok
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case paymentproviderinstance.FieldProviderKey:
+ return m.ProviderKey()
+ case paymentproviderinstance.FieldName:
+ return m.Name()
+ case paymentproviderinstance.FieldConfig:
+ return m.Config()
+ case paymentproviderinstance.FieldSupportedTypes:
+ return m.SupportedTypes()
+ case paymentproviderinstance.FieldEnabled:
+ return m.Enabled()
+ case paymentproviderinstance.FieldPaymentMode:
+ return m.PaymentMode()
+ case paymentproviderinstance.FieldSortOrder:
+ return m.SortOrder()
+ case paymentproviderinstance.FieldLimits:
+ return m.Limits()
+ case paymentproviderinstance.FieldRefundEnabled:
+ return m.RefundEnabled()
+ case paymentproviderinstance.FieldAllowUserRefund:
+ return m.AllowUserRefund()
+ case paymentproviderinstance.FieldCreatedAt:
+ return m.CreatedAt()
+ case paymentproviderinstance.FieldUpdatedAt:
+ return m.UpdatedAt()
+ }
+ return nil, false
}
-// ClearField clears the value of the field with the given name. It returns an
-// error if the field is not defined in the schema.
-func (m *PaymentOrderMutation) ClearField(name string) error {
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
- case paymentorder.FieldUserNotes:
- m.ClearUserNotes()
- return nil
- case paymentorder.FieldPayURL:
- m.ClearPayURL()
- return nil
- case paymentorder.FieldQrCode:
- m.ClearQrCode()
- return nil
- case paymentorder.FieldQrCodeImg:
- m.ClearQrCodeImg()
- return nil
- case paymentorder.FieldPlanID:
- m.ClearPlanID()
- return nil
- case paymentorder.FieldSubscriptionGroupID:
- m.ClearSubscriptionGroupID()
- return nil
- case paymentorder.FieldSubscriptionDays:
- m.ClearSubscriptionDays()
- return nil
- case paymentorder.FieldProviderInstanceID:
- m.ClearProviderInstanceID()
- return nil
- case paymentorder.FieldRefundReason:
- m.ClearRefundReason()
- return nil
- case paymentorder.FieldRefundAt:
- m.ClearRefundAt()
- return nil
- case paymentorder.FieldRefundRequestedAt:
- m.ClearRefundRequestedAt()
- return nil
- case paymentorder.FieldRefundRequestReason:
- m.ClearRefundRequestReason()
- return nil
- case paymentorder.FieldRefundRequestedBy:
- m.ClearRefundRequestedBy()
- return nil
- case paymentorder.FieldPaidAt:
- m.ClearPaidAt()
- return nil
- case paymentorder.FieldCompletedAt:
- m.ClearCompletedAt()
- return nil
- case paymentorder.FieldFailedAt:
- m.ClearFailedAt()
- return nil
- case paymentorder.FieldFailedReason:
- m.ClearFailedReason()
- return nil
- case paymentorder.FieldSrcURL:
- m.ClearSrcURL()
- return nil
+ case paymentproviderinstance.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case paymentproviderinstance.FieldName:
+ return m.OldName(ctx)
+ case paymentproviderinstance.FieldConfig:
+ return m.OldConfig(ctx)
+ case paymentproviderinstance.FieldSupportedTypes:
+ return m.OldSupportedTypes(ctx)
+ case paymentproviderinstance.FieldEnabled:
+ return m.OldEnabled(ctx)
+ case paymentproviderinstance.FieldPaymentMode:
+ return m.OldPaymentMode(ctx)
+ case paymentproviderinstance.FieldSortOrder:
+ return m.OldSortOrder(ctx)
+ case paymentproviderinstance.FieldLimits:
+ return m.OldLimits(ctx)
+ case paymentproviderinstance.FieldRefundEnabled:
+ return m.OldRefundEnabled(ctx)
+ case paymentproviderinstance.FieldAllowUserRefund:
+ return m.OldAllowUserRefund(ctx)
+ case paymentproviderinstance.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case paymentproviderinstance.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
}
- return fmt.Errorf("unknown PaymentOrder nullable field %s", name)
+ return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name)
}
-// ResetField resets all changes in the mutation for the field with the given name.
-// It returns an error if the field is not defined in the schema.
-func (m *PaymentOrderMutation) ResetField(name string) error {
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error {
switch name {
- case paymentorder.FieldUserID:
- m.ResetUserID()
- return nil
- case paymentorder.FieldUserEmail:
- m.ResetUserEmail()
- return nil
- case paymentorder.FieldUserName:
- m.ResetUserName()
- return nil
- case paymentorder.FieldUserNotes:
- m.ResetUserNotes()
- return nil
- case paymentorder.FieldAmount:
- m.ResetAmount()
- return nil
- case paymentorder.FieldPayAmount:
- m.ResetPayAmount()
- return nil
- case paymentorder.FieldFeeRate:
- m.ResetFeeRate()
- return nil
- case paymentorder.FieldRechargeCode:
- m.ResetRechargeCode()
- return nil
- case paymentorder.FieldOutTradeNo:
- m.ResetOutTradeNo()
- return nil
- case paymentorder.FieldPaymentType:
- m.ResetPaymentType()
- return nil
- case paymentorder.FieldPaymentTradeNo:
- m.ResetPaymentTradeNo()
- return nil
- case paymentorder.FieldPayURL:
- m.ResetPayURL()
- return nil
- case paymentorder.FieldQrCode:
- m.ResetQrCode()
+ case paymentproviderinstance.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
return nil
- case paymentorder.FieldQrCodeImg:
- m.ResetQrCodeImg()
+ case paymentproviderinstance.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
return nil
- case paymentorder.FieldOrderType:
- m.ResetOrderType()
+ case paymentproviderinstance.FieldConfig:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConfig(v)
return nil
- case paymentorder.FieldPlanID:
- m.ResetPlanID()
+ case paymentproviderinstance.FieldSupportedTypes:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSupportedTypes(v)
return nil
- case paymentorder.FieldSubscriptionGroupID:
- m.ResetSubscriptionGroupID()
+ case paymentproviderinstance.FieldEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEnabled(v)
return nil
- case paymentorder.FieldSubscriptionDays:
- m.ResetSubscriptionDays()
+ case paymentproviderinstance.FieldPaymentMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPaymentMode(v)
return nil
- case paymentorder.FieldProviderInstanceID:
- m.ResetProviderInstanceID()
+ case paymentproviderinstance.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSortOrder(v)
return nil
- case paymentorder.FieldStatus:
- m.ResetStatus()
+ case paymentproviderinstance.FieldLimits:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLimits(v)
return nil
- case paymentorder.FieldRefundAmount:
- m.ResetRefundAmount()
+ case paymentproviderinstance.FieldRefundEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundEnabled(v)
return nil
- case paymentorder.FieldRefundReason:
- m.ResetRefundReason()
+ case paymentproviderinstance.FieldAllowUserRefund:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAllowUserRefund(v)
return nil
- case paymentorder.FieldRefundAt:
- m.ResetRefundAt()
+ case paymentproviderinstance.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
return nil
- case paymentorder.FieldForceRefund:
- m.ResetForceRefund()
+ case paymentproviderinstance.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
return nil
- case paymentorder.FieldRefundRequestedAt:
- m.ResetRefundRequestedAt()
+ }
+ return fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PaymentProviderInstanceMutation) AddedFields() []string {
+ var fields []string
+ if m.addsort_order != nil {
+ fields = append(fields, paymentproviderinstance.FieldSortOrder)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case paymentproviderinstance.FieldSortOrder:
+ return m.AddedSortOrder()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case paymentproviderinstance.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSortOrder(v)
return nil
- case paymentorder.FieldRefundRequestReason:
- m.ResetRefundRequestReason()
+ }
+ return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PaymentProviderInstanceMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
+ switch name {
+ case paymentproviderinstance.FieldProviderKey:
+ m.ResetProviderKey()
return nil
- case paymentorder.FieldRefundRequestedBy:
- m.ResetRefundRequestedBy()
+ case paymentproviderinstance.FieldName:
+ m.ResetName()
return nil
- case paymentorder.FieldExpiresAt:
- m.ResetExpiresAt()
+ case paymentproviderinstance.FieldConfig:
+ m.ResetConfig()
return nil
- case paymentorder.FieldPaidAt:
- m.ResetPaidAt()
+ case paymentproviderinstance.FieldSupportedTypes:
+ m.ResetSupportedTypes()
return nil
- case paymentorder.FieldCompletedAt:
- m.ResetCompletedAt()
+ case paymentproviderinstance.FieldEnabled:
+ m.ResetEnabled()
return nil
- case paymentorder.FieldFailedAt:
- m.ResetFailedAt()
+ case paymentproviderinstance.FieldPaymentMode:
+ m.ResetPaymentMode()
return nil
- case paymentorder.FieldFailedReason:
- m.ResetFailedReason()
+ case paymentproviderinstance.FieldSortOrder:
+ m.ResetSortOrder()
return nil
- case paymentorder.FieldClientIP:
- m.ResetClientIP()
+ case paymentproviderinstance.FieldLimits:
+ m.ResetLimits()
return nil
- case paymentorder.FieldSrcHost:
- m.ResetSrcHost()
+ case paymentproviderinstance.FieldRefundEnabled:
+ m.ResetRefundEnabled()
return nil
- case paymentorder.FieldSrcURL:
- m.ResetSrcURL()
+ case paymentproviderinstance.FieldAllowUserRefund:
+ m.ResetAllowUserRefund()
return nil
- case paymentorder.FieldCreatedAt:
+ case paymentproviderinstance.FieldCreatedAt:
m.ResetCreatedAt()
return nil
- case paymentorder.FieldUpdatedAt:
+ case paymentproviderinstance.FieldUpdatedAt:
m.ResetUpdatedAt()
return nil
}
- return fmt.Errorf("unknown PaymentOrder field %s", name)
+ return fmt.Errorf("unknown PaymentProviderInstance field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
-func (m *PaymentOrderMutation) AddedEdges() []string {
- edges := make([]string, 0, 1)
- if m.user != nil {
- edges = append(edges, paymentorder.EdgeUser)
- }
+func (m *PaymentProviderInstanceMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
return edges
}
// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
-func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value {
- switch name {
- case paymentorder.EdgeUser:
- if id := m.user; id != nil {
- return []ent.Value{*id}
- }
- }
+func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value {
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
-func (m *PaymentOrderMutation) RemovedEdges() []string {
- edges := make([]string, 0, 1)
+func (m *PaymentProviderInstanceMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
-func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value {
+func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value {
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
-func (m *PaymentOrderMutation) ClearedEdges() []string {
- edges := make([]string, 0, 1)
- if m.cleareduser {
- edges = append(edges, paymentorder.EdgeUser)
- }
+func (m *PaymentProviderInstanceMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
-func (m *PaymentOrderMutation) EdgeCleared(name string) bool {
- switch name {
- case paymentorder.EdgeUser:
- return m.cleareduser
- }
- return false
-}
-
-// ClearEdge clears the value of the edge with the given name. It returns an error
-// if that edge is not defined in the schema.
-func (m *PaymentOrderMutation) ClearEdge(name string) error {
- switch name {
- case paymentorder.EdgeUser:
- m.ClearUser()
- return nil
- }
- return fmt.Errorf("unknown PaymentOrder unique edge %s", name)
-}
-
-// ResetEdge resets all changes to the edge with the given name in this mutation.
-// It returns an error if the edge is not defined in the schema.
-func (m *PaymentOrderMutation) ResetEdge(name string) error {
- switch name {
- case paymentorder.EdgeUser:
- m.ResetUser()
- return nil
- }
- return fmt.Errorf("unknown PaymentOrder edge %s", name)
-}
-
-// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
-type PaymentProviderInstanceMutation struct {
- config
- op Op
- typ string
- id *int64
- provider_key *string
- name *string
- _config *string
- supported_types *string
- enabled *bool
- payment_mode *string
- sort_order *int
- addsort_order *int
- limits *string
- refund_enabled *bool
- allow_user_refund *bool
- created_at *time.Time
- updated_at *time.Time
- clearedFields map[string]struct{}
- done bool
- oldValue func(context.Context) (*PaymentProviderInstance, error)
- predicates []predicate.PaymentProviderInstance
+func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool {
+ return false
}
-var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name)
+}
-// paymentproviderinstanceOption allows management of the mutation configuration using functional options.
-type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation)
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown PaymentProviderInstance edge %s", name)
+}
-// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity.
-func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation {
- m := &PaymentProviderInstanceMutation{
+// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph.
+type PendingAuthSessionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ session_token *string
+ intent *string
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ redirect_to *string
+ resolved_email *string
+ registration_password_hash *string
+ upstream_identity_claims *map[string]interface{}
+ local_flow_state *map[string]interface{}
+ browser_session_key *string
+ completion_code_hash *string
+ completion_code_expires_at *time.Time
+ email_verified_at *time.Time
+ password_verified_at *time.Time
+ totp_verified_at *time.Time
+ expires_at *time.Time
+ consumed_at *time.Time
+ clearedFields map[string]struct{}
+ target_user *int64
+ clearedtarget_user bool
+ adoption_decision *int64
+ clearedadoption_decision bool
+ done bool
+ oldValue func(context.Context) (*PendingAuthSession, error)
+ predicates []predicate.PendingAuthSession
+}
+
+var _ ent.Mutation = (*PendingAuthSessionMutation)(nil)
+
+// pendingauthsessionOption allows management of the mutation configuration using functional options.
+type pendingauthsessionOption func(*PendingAuthSessionMutation)
+
+// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity.
+func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation {
+ m := &PendingAuthSessionMutation{
config: c,
op: op,
- typ: TypePaymentProviderInstance,
+ typ: TypePendingAuthSession,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
@@ -15683,20 +19272,20 @@ func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentprovider
return m
}
-// withPaymentProviderInstanceID sets the ID field of the mutation.
-func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption {
- return func(m *PaymentProviderInstanceMutation) {
+// withPendingAuthSessionID sets the ID field of the mutation.
+func withPendingAuthSessionID(id int64) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
var (
err error
once sync.Once
- value *PaymentProviderInstance
+ value *PendingAuthSession
)
- m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) {
+ m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
- value, err = m.Client().PaymentProviderInstance.Get(ctx, id)
+ value, err = m.Client().PendingAuthSession.Get(ctx, id)
}
})
return value, err
@@ -15705,10 +19294,10 @@ func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption {
}
}
-// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation.
-func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption {
- return func(m *PaymentProviderInstanceMutation) {
- m.oldValue = func(context.Context) (*PaymentProviderInstance, error) {
+// withPendingAuthSession sets the old PendingAuthSession of the mutation.
+func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ m.oldValue = func(context.Context) (*PendingAuthSession, error) {
return node, nil
}
m.id = &node.ID
@@ -15717,7 +19306,7 @@ func withPaymentProviderInstance(node *PaymentProviderInstance) paymentprovideri
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m PaymentProviderInstanceMutation) Client() *Client {
+func (m PendingAuthSessionMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
@@ -15725,7 +19314,7 @@ func (m PaymentProviderInstanceMutation) Client() *Client {
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
-func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) {
+func (m PendingAuthSessionMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
@@ -15734,495 +19323,943 @@ func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) {
return tx, nil
}
-// ID returns the ID value in the mutation. Note that the ID is only available
-// if it was provided to the builder or after it was returned from the database.
-func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) {
- if m.id == nil {
- return
- }
- return *m.id, true
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PendingAuthSessionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PendingAuthSessionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetSessionToken sets the "session_token" field.
+func (m *PendingAuthSessionMutation) SetSessionToken(s string) {
+ m.session_token = &s
+}
+
+// SessionToken returns the value of the "session_token" field in the mutation.
+func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) {
+ v := m.session_token
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSessionToken is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSessionToken requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSessionToken: %w", err)
+ }
+ return oldValue.SessionToken, nil
+}
+
+// ResetSessionToken resets all changes to the "session_token" field.
+func (m *PendingAuthSessionMutation) ResetSessionToken() {
+ m.session_token = nil
+}
+
+// SetIntent sets the "intent" field.
+func (m *PendingAuthSessionMutation) SetIntent(s string) {
+ m.intent = &s
+}
+
+// Intent returns the value of the "intent" field in the mutation.
+func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) {
+ v := m.intent
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntent returns the old "intent" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntent: %w", err)
+ }
+ return oldValue.Intent, nil
+}
+
+// ResetIntent resets all changes to the "intent" field.
+func (m *PendingAuthSessionMutation) ResetIntent() {
+ m.intent = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *PendingAuthSessionMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *PendingAuthSessionMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PendingAuthSessionMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PendingAuthSessionMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *PendingAuthSessionMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *PendingAuthSessionMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) {
+ m.target_user = &i
+}
+
+// TargetUserID returns the value of the "target_user_id" field in the mutation.
+func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) {
+ v := m.target_user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTargetUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err)
+ }
+ return oldValue.TargetUserID, nil
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ClearTargetUserID() {
+ m.target_user = nil
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
}
-// IDs queries the database and returns the entity ids that match the mutation's predicate.
-// That means, if the mutation is applied within a transaction with an isolation level such
-// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
-// or updated by the mutation.
-func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) {
- switch {
- case m.op.Is(OpUpdateOne | OpDeleteOne):
- id, exists := m.ID()
- if exists {
- return []int64{id}, nil
- }
- fallthrough
- case m.op.Is(OpUpdate | OpDelete):
- return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx)
- default:
- return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
- }
+// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID]
+ return ok
}
-// SetProviderKey sets the "provider_key" field.
-func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) {
- m.provider_key = &s
+// ResetTargetUserID resets all changes to the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ResetTargetUserID() {
+ m.target_user = nil
+ delete(m.clearedFields, pendingauthsession.FieldTargetUserID)
}
-// ProviderKey returns the value of the "provider_key" field in the mutation.
-func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) {
- v := m.provider_key
+// SetRedirectTo sets the "redirect_to" field.
+func (m *PendingAuthSessionMutation) SetRedirectTo(s string) {
+ m.redirect_to = &s
+}
+
+// RedirectTo returns the value of the "redirect_to" field in the mutation.
+func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) {
+ v := m.redirect_to
if v == nil {
return
}
return *v, true
}
-// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ return v, errors.New("OldRedirectTo requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err)
}
- return oldValue.ProviderKey, nil
+ return oldValue.RedirectTo, nil
}
-// ResetProviderKey resets all changes to the "provider_key" field.
-func (m *PaymentProviderInstanceMutation) ResetProviderKey() {
- m.provider_key = nil
+// ResetRedirectTo resets all changes to the "redirect_to" field.
+func (m *PendingAuthSessionMutation) ResetRedirectTo() {
+ m.redirect_to = nil
}
-// SetName sets the "name" field.
-func (m *PaymentProviderInstanceMutation) SetName(s string) {
- m.name = &s
+// SetResolvedEmail sets the "resolved_email" field.
+func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) {
+ m.resolved_email = &s
}
-// Name returns the value of the "name" field in the mutation.
-func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) {
- v := m.name
+// ResolvedEmail returns the value of the "resolved_email" field in the mutation.
+func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) {
+ v := m.resolved_email
if v == nil {
return
}
return *v, true
}
-// OldName returns the old "name" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) {
+func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldName is only allowed on UpdateOne operations")
+ return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldName requires an ID field in the mutation")
+ return v, errors.New("OldResolvedEmail requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldName: %w", err)
+ return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err)
}
- return oldValue.Name, nil
+ return oldValue.ResolvedEmail, nil
}
-// ResetName resets all changes to the "name" field.
-func (m *PaymentProviderInstanceMutation) ResetName() {
- m.name = nil
+// ResetResolvedEmail resets all changes to the "resolved_email" field.
+func (m *PendingAuthSessionMutation) ResetResolvedEmail() {
+ m.resolved_email = nil
}
-// SetConfig sets the "config" field.
-func (m *PaymentProviderInstanceMutation) SetConfig(s string) {
- m._config = &s
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) {
+ m.registration_password_hash = &s
}
-// Config returns the value of the "config" field in the mutation.
-func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) {
- v := m._config
+// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) {
+ v := m.registration_password_hash
if v == nil {
return
}
return *v, true
}
-// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) {
+func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldConfig is only allowed on UpdateOne operations")
+ return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldConfig requires an ID field in the mutation")
+ return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldConfig: %w", err)
+ return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err)
}
- return oldValue.Config, nil
+ return oldValue.RegistrationPasswordHash, nil
}
-// ResetConfig resets all changes to the "config" field.
-func (m *PaymentProviderInstanceMutation) ResetConfig() {
- m._config = nil
+// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() {
+ m.registration_password_hash = nil
}
-// SetSupportedTypes sets the "supported_types" field.
-func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) {
- m.supported_types = &s
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) {
+ m.upstream_identity_claims = &value
}
-// SupportedTypes returns the value of the "supported_types" field in the mutation.
-func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) {
- v := m.supported_types
+// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation.
+func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) {
+ v := m.upstream_identity_claims
if v == nil {
return
}
return *v, true
}
-// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) {
+func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations")
+ return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSupportedTypes requires an ID field in the mutation")
+ return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err)
+ return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err)
}
- return oldValue.SupportedTypes, nil
+ return oldValue.UpstreamIdentityClaims, nil
}
-// ResetSupportedTypes resets all changes to the "supported_types" field.
-func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() {
- m.supported_types = nil
+// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() {
+ m.upstream_identity_claims = nil
}
-// SetEnabled sets the "enabled" field.
-func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) {
- m.enabled = &b
+// SetLocalFlowState sets the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) {
+ m.local_flow_state = &value
}
-// Enabled returns the value of the "enabled" field in the mutation.
-func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) {
- v := m.enabled
+// LocalFlowState returns the value of the "local_flow_state" field in the mutation.
+func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) {
+ v := m.local_flow_state
if v == nil {
return
}
return *v, true
}
-// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) {
+func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
+ return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldEnabled requires an ID field in the mutation")
+ return v, errors.New("OldLocalFlowState requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
+ return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err)
}
- return oldValue.Enabled, nil
+ return oldValue.LocalFlowState, nil
}
-// ResetEnabled resets all changes to the "enabled" field.
-func (m *PaymentProviderInstanceMutation) ResetEnabled() {
- m.enabled = nil
+// ResetLocalFlowState resets all changes to the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) ResetLocalFlowState() {
+ m.local_flow_state = nil
}
-// SetPaymentMode sets the "payment_mode" field.
-func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) {
- m.payment_mode = &s
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) {
+ m.browser_session_key = &s
}
-// PaymentMode returns the value of the "payment_mode" field in the mutation.
-func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) {
- v := m.payment_mode
+// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation.
+func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) {
+ v := m.browser_session_key
if v == nil {
return
}
return *v, true
}
-// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) {
+func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations")
+ return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPaymentMode requires an ID field in the mutation")
+ return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err)
+ return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err)
}
- return oldValue.PaymentMode, nil
+ return oldValue.BrowserSessionKey, nil
}
-// ResetPaymentMode resets all changes to the "payment_mode" field.
-func (m *PaymentProviderInstanceMutation) ResetPaymentMode() {
- m.payment_mode = nil
+// ResetBrowserSessionKey resets all changes to the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() {
+ m.browser_session_key = nil
}
-// SetSortOrder sets the "sort_order" field.
-func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) {
- m.sort_order = &i
- m.addsort_order = nil
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) {
+ m.completion_code_hash = &s
}
-// SortOrder returns the value of the "sort_order" field in the mutation.
-func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) {
- v := m.sort_order
+// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) {
+ v := m.completion_code_hash
if v == nil {
return
}
return *v, true
}
-// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) {
+func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
+ return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSortOrder requires an ID field in the mutation")
+ return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
+ return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err)
}
- return oldValue.SortOrder, nil
+ return oldValue.CompletionCodeHash, nil
}
-// AddSortOrder adds i to the "sort_order" field.
-func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) {
- if m.addsort_order != nil {
- *m.addsort_order += i
- } else {
- m.addsort_order = &i
- }
+// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() {
+ m.completion_code_hash = nil
}
-// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
-func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) {
- v := m.addsort_order
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) {
+ m.completion_code_expires_at = &t
+}
+
+// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) {
+ v := m.completion_code_expires_at
if v == nil {
return
}
return *v, true
}
-// ResetSortOrder resets all changes to the "sort_order" field.
-func (m *PaymentProviderInstanceMutation) ResetSortOrder() {
- m.sort_order = nil
- m.addsort_order = nil
+// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err)
+ }
+ return oldValue.CompletionCodeExpiresAt, nil
}
-// SetLimits sets the "limits" field.
-func (m *PaymentProviderInstanceMutation) SetLimits(s string) {
- m.limits = &s
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{}
}
-// Limits returns the value of the "limits" field in the mutation.
-func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) {
- v := m.limits
+// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt]
+ return ok
+}
+
+// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt)
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) {
+ m.email_verified_at = &t
+}
+
+// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) {
+ v := m.email_verified_at
if v == nil {
return
}
return *v, true
}
-// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) {
+func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldLimits is only allowed on UpdateOne operations")
+ return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldLimits requires an ID field in the mutation")
+ return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldLimits: %w", err)
+ return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err)
}
- return oldValue.Limits, nil
+ return oldValue.EmailVerifiedAt, nil
}
-// ResetLimits resets all changes to the "limits" field.
-func (m *PaymentProviderInstanceMutation) ResetLimits() {
- m.limits = nil
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() {
+ m.email_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{}
}
-// SetRefundEnabled sets the "refund_enabled" field.
-func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) {
- m.refund_enabled = &b
+// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt]
+ return ok
}
-// RefundEnabled returns the value of the "refund_enabled" field in the mutation.
-func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) {
- v := m.refund_enabled
+// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() {
+ m.email_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt)
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) {
+ m.password_verified_at = &t
+}
+
+// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) {
+ v := m.password_verified_at
if v == nil {
return
}
return *v, true
}
-// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) {
+func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations")
+ return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldRefundEnabled requires an ID field in the mutation")
+ return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err)
+ return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err)
}
- return oldValue.RefundEnabled, nil
+ return oldValue.PasswordVerifiedAt, nil
}
-// ResetRefundEnabled resets all changes to the "refund_enabled" field.
-func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
- m.refund_enabled = nil
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{}
}
-// SetAllowUserRefund sets the "allow_user_refund" field.
-func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
- m.allow_user_refund = &b
+// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt]
+ return ok
}
-// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
-func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
- v := m.allow_user_refund
+// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt)
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) {
+ m.totp_verified_at = &t
+}
+
+// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) {
+ v := m.totp_verified_at
if v == nil {
return
}
return *v, true
}
-// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
+func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
+ return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
+ return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
+ return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err)
}
- return oldValue.AllowUserRefund, nil
+ return oldValue.TotpVerifiedAt, nil
}
-// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
-func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
- m.allow_user_refund = nil
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{}
}
-// SetCreatedAt sets the "created_at" field.
-func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
- m.created_at = &t
+// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt]
+ return ok
}
-// CreatedAt returns the value of the "created_at" field in the mutation.
-func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) {
- v := m.created_at
+// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
if v == nil {
return
}
return *v, true
}
-// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
}
- return oldValue.CreatedAt, nil
+ return oldValue.ExpiresAt, nil
}
-// ResetCreatedAt resets all changes to the "created_at" field.
-func (m *PaymentProviderInstanceMutation) ResetCreatedAt() {
- m.created_at = nil
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PendingAuthSessionMutation) ResetExpiresAt() {
+ m.expires_at = nil
}
-// SetUpdatedAt sets the "updated_at" field.
-func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) {
- m.updated_at = &t
+// SetConsumedAt sets the "consumed_at" field.
+func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) {
+ m.consumed_at = &t
}
-// UpdatedAt returns the value of the "updated_at" field in the mutation.
-func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) {
- v := m.updated_at
+// ConsumedAt returns the value of the "consumed_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) {
+ v := m.consumed_at
if v == nil {
return
}
return *v, true
}
-// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity.
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ return v, errors.New("OldConsumedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err)
}
- return oldValue.UpdatedAt, nil
+ return oldValue.ConsumedAt, nil
}
-// ResetUpdatedAt resets all changes to the "updated_at" field.
-func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() {
- m.updated_at = nil
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ClearConsumedAt() {
+ m.consumed_at = nil
+ m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{}
}
-// Where appends a list predicates to the PaymentProviderInstanceMutation builder.
-func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) {
+// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt]
+ return ok
+}
+
+// ResetConsumedAt resets all changes to the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ResetConsumedAt() {
+ m.consumed_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldConsumedAt)
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (m *PendingAuthSessionMutation) ClearTargetUser() {
+ m.clearedtarget_user = true
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserCleared reports if the "target_user" edge to the User entity was cleared.
+func (m *PendingAuthSessionMutation) TargetUserCleared() bool {
+ return m.TargetUserIDCleared() || m.clearedtarget_user
+}
+
+// TargetUserIDs returns the "target_user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// TargetUserID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) {
+ if id := m.target_user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetTargetUser resets all changes to the "target_user" edge.
+func (m *PendingAuthSessionMutation) ResetTargetUser() {
+ m.target_user = nil
+ m.clearedtarget_user = false
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id.
+func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) {
+ m.adoption_decision = &id
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (m *PendingAuthSessionMutation) ClearAdoptionDecision() {
+ m.clearedadoption_decision = true
+}
+
+// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool {
+ return m.clearedadoption_decision
+}
+
+// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation.
+func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) {
+ if m.adoption_decision != nil {
+ return *m.adoption_decision, true
+ }
+ return
+}
+
+// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// AdoptionDecisionID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) {
+ if id := m.adoption_decision; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetAdoptionDecision resets all changes to the "adoption_decision" edge.
+func (m *PendingAuthSessionMutation) ResetAdoptionDecision() {
+ m.adoption_decision = nil
+ m.clearedadoption_decision = false
+}
+
+// Where appends a list predicates to the PendingAuthSessionMutation builder.
+func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) {
m.predicates = append(m.predicates, ps...)
}
-// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method,
+// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
-func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) {
- p := make([]predicate.PaymentProviderInstance, len(ps))
+func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PendingAuthSession, len(ps))
for i := range ps {
p[i] = ps[i]
}
@@ -16230,60 +20267,87 @@ func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) {
}
// Op returns the operation name.
-func (m *PaymentProviderInstanceMutation) Op() Op {
+func (m *PendingAuthSessionMutation) Op() Op {
return m.op
}
// SetOp allows setting the mutation operation.
-func (m *PaymentProviderInstanceMutation) SetOp(op Op) {
+func (m *PendingAuthSessionMutation) SetOp(op Op) {
m.op = op
}
-// Type returns the node type of this mutation (PaymentProviderInstance).
-func (m *PaymentProviderInstanceMutation) Type() string {
+// Type returns the node type of this mutation (PendingAuthSession).
+func (m *PendingAuthSessionMutation) Type() string {
return m.typ
}
// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
-func (m *PaymentProviderInstanceMutation) Fields() []string {
- fields := make([]string, 0, 12)
+func (m *PendingAuthSessionMutation) Fields() []string {
+ fields := make([]string, 0, 21)
+ if m.created_at != nil {
+ fields = append(fields, pendingauthsession.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, pendingauthsession.FieldUpdatedAt)
+ }
+ if m.session_token != nil {
+ fields = append(fields, pendingauthsession.FieldSessionToken)
+ }
+ if m.intent != nil {
+ fields = append(fields, pendingauthsession.FieldIntent)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, pendingauthsession.FieldProviderType)
+ }
if m.provider_key != nil {
- fields = append(fields, paymentproviderinstance.FieldProviderKey)
+ fields = append(fields, pendingauthsession.FieldProviderKey)
}
- if m.name != nil {
- fields = append(fields, paymentproviderinstance.FieldName)
+ if m.provider_subject != nil {
+ fields = append(fields, pendingauthsession.FieldProviderSubject)
}
- if m._config != nil {
- fields = append(fields, paymentproviderinstance.FieldConfig)
+ if m.target_user != nil {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
}
- if m.supported_types != nil {
- fields = append(fields, paymentproviderinstance.FieldSupportedTypes)
+ if m.redirect_to != nil {
+ fields = append(fields, pendingauthsession.FieldRedirectTo)
}
- if m.enabled != nil {
- fields = append(fields, paymentproviderinstance.FieldEnabled)
+ if m.resolved_email != nil {
+ fields = append(fields, pendingauthsession.FieldResolvedEmail)
}
- if m.payment_mode != nil {
- fields = append(fields, paymentproviderinstance.FieldPaymentMode)
+ if m.registration_password_hash != nil {
+ fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash)
}
- if m.sort_order != nil {
- fields = append(fields, paymentproviderinstance.FieldSortOrder)
+ if m.upstream_identity_claims != nil {
+ fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims)
}
- if m.limits != nil {
- fields = append(fields, paymentproviderinstance.FieldLimits)
+ if m.local_flow_state != nil {
+ fields = append(fields, pendingauthsession.FieldLocalFlowState)
}
- if m.refund_enabled != nil {
- fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
+ if m.browser_session_key != nil {
+ fields = append(fields, pendingauthsession.FieldBrowserSessionKey)
}
- if m.allow_user_refund != nil {
- fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
+ if m.completion_code_hash != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeHash)
}
- if m.created_at != nil {
- fields = append(fields, paymentproviderinstance.FieldCreatedAt)
+ if m.completion_code_expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
}
- if m.updated_at != nil {
- fields = append(fields, paymentproviderinstance.FieldUpdatedAt)
+ if m.email_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.password_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.totp_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldExpiresAt)
+ }
+ if m.consumed_at != nil {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
}
return fields
}
@@ -16291,32 +20355,50 @@ func (m *PaymentProviderInstanceMutation) Fields() []string {
// Field returns the value of a field with the given name. The second boolean
// return value indicates that this field was not set, or was not defined in the
// schema.
-func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
+func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) {
switch name {
- case paymentproviderinstance.FieldProviderKey:
- return m.ProviderKey()
- case paymentproviderinstance.FieldName:
- return m.Name()
- case paymentproviderinstance.FieldConfig:
- return m.Config()
- case paymentproviderinstance.FieldSupportedTypes:
- return m.SupportedTypes()
- case paymentproviderinstance.FieldEnabled:
- return m.Enabled()
- case paymentproviderinstance.FieldPaymentMode:
- return m.PaymentMode()
- case paymentproviderinstance.FieldSortOrder:
- return m.SortOrder()
- case paymentproviderinstance.FieldLimits:
- return m.Limits()
- case paymentproviderinstance.FieldRefundEnabled:
- return m.RefundEnabled()
- case paymentproviderinstance.FieldAllowUserRefund:
- return m.AllowUserRefund()
- case paymentproviderinstance.FieldCreatedAt:
+ case pendingauthsession.FieldCreatedAt:
return m.CreatedAt()
- case paymentproviderinstance.FieldUpdatedAt:
+ case pendingauthsession.FieldUpdatedAt:
return m.UpdatedAt()
+ case pendingauthsession.FieldSessionToken:
+ return m.SessionToken()
+ case pendingauthsession.FieldIntent:
+ return m.Intent()
+ case pendingauthsession.FieldProviderType:
+ return m.ProviderType()
+ case pendingauthsession.FieldProviderKey:
+ return m.ProviderKey()
+ case pendingauthsession.FieldProviderSubject:
+ return m.ProviderSubject()
+ case pendingauthsession.FieldTargetUserID:
+ return m.TargetUserID()
+ case pendingauthsession.FieldRedirectTo:
+ return m.RedirectTo()
+ case pendingauthsession.FieldResolvedEmail:
+ return m.ResolvedEmail()
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.RegistrationPasswordHash()
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.UpstreamIdentityClaims()
+ case pendingauthsession.FieldLocalFlowState:
+ return m.LocalFlowState()
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.BrowserSessionKey()
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.CompletionCodeHash()
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.CompletionCodeExpiresAt()
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.EmailVerifiedAt()
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.PasswordVerifiedAt()
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.TotpVerifiedAt()
+ case pendingauthsession.FieldExpiresAt:
+ return m.ExpiresAt()
+ case pendingauthsession.FieldConsumedAt:
+ return m.ConsumedAt()
}
return nil, false
}
@@ -16324,146 +20406,222 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
-func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
- case paymentproviderinstance.FieldProviderKey:
- return m.OldProviderKey(ctx)
- case paymentproviderinstance.FieldName:
- return m.OldName(ctx)
- case paymentproviderinstance.FieldConfig:
- return m.OldConfig(ctx)
- case paymentproviderinstance.FieldSupportedTypes:
- return m.OldSupportedTypes(ctx)
- case paymentproviderinstance.FieldEnabled:
- return m.OldEnabled(ctx)
- case paymentproviderinstance.FieldPaymentMode:
- return m.OldPaymentMode(ctx)
- case paymentproviderinstance.FieldSortOrder:
- return m.OldSortOrder(ctx)
- case paymentproviderinstance.FieldLimits:
- return m.OldLimits(ctx)
- case paymentproviderinstance.FieldRefundEnabled:
- return m.OldRefundEnabled(ctx)
- case paymentproviderinstance.FieldAllowUserRefund:
- return m.OldAllowUserRefund(ctx)
- case paymentproviderinstance.FieldCreatedAt:
+ case pendingauthsession.FieldCreatedAt:
return m.OldCreatedAt(ctx)
- case paymentproviderinstance.FieldUpdatedAt:
+ case pendingauthsession.FieldUpdatedAt:
return m.OldUpdatedAt(ctx)
+ case pendingauthsession.FieldSessionToken:
+ return m.OldSessionToken(ctx)
+ case pendingauthsession.FieldIntent:
+ return m.OldIntent(ctx)
+ case pendingauthsession.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case pendingauthsession.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case pendingauthsession.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case pendingauthsession.FieldTargetUserID:
+ return m.OldTargetUserID(ctx)
+ case pendingauthsession.FieldRedirectTo:
+ return m.OldRedirectTo(ctx)
+ case pendingauthsession.FieldResolvedEmail:
+ return m.OldResolvedEmail(ctx)
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.OldRegistrationPasswordHash(ctx)
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.OldUpstreamIdentityClaims(ctx)
+ case pendingauthsession.FieldLocalFlowState:
+ return m.OldLocalFlowState(ctx)
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.OldBrowserSessionKey(ctx)
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.OldCompletionCodeHash(ctx)
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.OldCompletionCodeExpiresAt(ctx)
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.OldEmailVerifiedAt(ctx)
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.OldPasswordVerifiedAt(ctx)
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.OldTotpVerifiedAt(ctx)
+ case pendingauthsession.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case pendingauthsession.FieldConsumedAt:
+ return m.OldConsumedAt(ctx)
}
- return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+ return nil, fmt.Errorf("unknown PendingAuthSession field %s", name)
}
// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
-func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error {
+func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error {
switch name {
- case paymentproviderinstance.FieldProviderKey:
+ case pendingauthsession.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case pendingauthsession.FieldSessionToken:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetProviderKey(v)
+ m.SetSessionToken(v)
return nil
- case paymentproviderinstance.FieldName:
+ case pendingauthsession.FieldIntent:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetName(v)
+ m.SetIntent(v)
return nil
- case paymentproviderinstance.FieldConfig:
+ case pendingauthsession.FieldProviderType:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetConfig(v)
+ m.SetProviderType(v)
return nil
- case paymentproviderinstance.FieldSupportedTypes:
+ case pendingauthsession.FieldProviderKey:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetSupportedTypes(v)
+ m.SetProviderKey(v)
return nil
- case paymentproviderinstance.FieldEnabled:
- v, ok := value.(bool)
+ case pendingauthsession.FieldProviderSubject:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetEnabled(v)
+ m.SetProviderSubject(v)
return nil
- case paymentproviderinstance.FieldPaymentMode:
+ case pendingauthsession.FieldTargetUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTargetUserID(v)
+ return nil
+ case pendingauthsession.FieldRedirectTo:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetPaymentMode(v)
+ m.SetRedirectTo(v)
return nil
- case paymentproviderinstance.FieldSortOrder:
- v, ok := value.(int)
+ case pendingauthsession.FieldResolvedEmail:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetSortOrder(v)
+ m.SetResolvedEmail(v)
return nil
- case paymentproviderinstance.FieldLimits:
+ case pendingauthsession.FieldRegistrationPasswordHash:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetLimits(v)
+ m.SetRegistrationPasswordHash(v)
return nil
- case paymentproviderinstance.FieldRefundEnabled:
- v, ok := value.(bool)
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ v, ok := value.(map[string]interface{})
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetRefundEnabled(v)
+ m.SetUpstreamIdentityClaims(v)
return nil
- case paymentproviderinstance.FieldAllowUserRefund:
- v, ok := value.(bool)
+ case pendingauthsession.FieldLocalFlowState:
+ v, ok := value.(map[string]interface{})
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetAllowUserRefund(v)
+ m.SetLocalFlowState(v)
return nil
- case paymentproviderinstance.FieldCreatedAt:
+ case pendingauthsession.FieldBrowserSessionKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBrowserSessionKey(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeHash(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetCreatedAt(v)
+ m.SetCompletionCodeExpiresAt(v)
return nil
- case paymentproviderinstance.FieldUpdatedAt:
+ case pendingauthsession.FieldEmailVerifiedAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetUpdatedAt(v)
+ m.SetEmailVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPasswordVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConsumedAt(v)
return nil
}
- return fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
}
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
-func (m *PaymentProviderInstanceMutation) AddedFields() []string {
+func (m *PendingAuthSessionMutation) AddedFields() []string {
var fields []string
- if m.addsort_order != nil {
- fields = append(fields, paymentproviderinstance.FieldSortOrder)
- }
return fields
}
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
-func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) {
+func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) {
switch name {
- case paymentproviderinstance.FieldSortOrder:
- return m.AddedSortOrder()
}
return nil, false
}
@@ -16471,128 +20629,231 @@ func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bo
// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
-func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error {
+func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error {
switch name {
- case paymentproviderinstance.FieldSortOrder:
- v, ok := value.(int)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSortOrder(v)
- return nil
}
- return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name)
+ return fmt.Errorf("unknown PendingAuthSession numeric field %s", name)
}
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
-func (m *PaymentProviderInstanceMutation) ClearedFields() []string {
- return nil
+func (m *PendingAuthSessionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(pendingauthsession.FieldTargetUserID) {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldConsumedAt) {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
}
// FieldCleared returns a boolean indicating if a field with the given name was
// cleared in this mutation.
-func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool {
+func (m *PendingAuthSessionMutation) FieldCleared(name string) bool {
_, ok := m.clearedFields[name]
return ok
}
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
-func (m *PaymentProviderInstanceMutation) ClearField(name string) error {
- return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name)
+func (m *PendingAuthSessionMutation) ClearField(name string) error {
+ switch name {
+ case pendingauthsession.FieldTargetUserID:
+ m.ClearTargetUserID()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ClearCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ClearEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ClearPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ClearTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ClearConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession nullable field %s", name)
}
// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
-func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
+func (m *PendingAuthSessionMutation) ResetField(name string) error {
switch name {
- case paymentproviderinstance.FieldProviderKey:
+ case pendingauthsession.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ m.ResetSessionToken()
+ return nil
+ case pendingauthsession.FieldIntent:
+ m.ResetIntent()
+ return nil
+ case pendingauthsession.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case pendingauthsession.FieldProviderKey:
m.ResetProviderKey()
return nil
- case paymentproviderinstance.FieldName:
- m.ResetName()
+ case pendingauthsession.FieldProviderSubject:
+ m.ResetProviderSubject()
return nil
- case paymentproviderinstance.FieldConfig:
- m.ResetConfig()
+ case pendingauthsession.FieldTargetUserID:
+ m.ResetTargetUserID()
return nil
- case paymentproviderinstance.FieldSupportedTypes:
- m.ResetSupportedTypes()
+ case pendingauthsession.FieldRedirectTo:
+ m.ResetRedirectTo()
return nil
- case paymentproviderinstance.FieldEnabled:
- m.ResetEnabled()
+ case pendingauthsession.FieldResolvedEmail:
+ m.ResetResolvedEmail()
return nil
- case paymentproviderinstance.FieldPaymentMode:
- m.ResetPaymentMode()
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ m.ResetRegistrationPasswordHash()
return nil
- case paymentproviderinstance.FieldSortOrder:
- m.ResetSortOrder()
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ m.ResetUpstreamIdentityClaims()
return nil
- case paymentproviderinstance.FieldLimits:
- m.ResetLimits()
+ case pendingauthsession.FieldLocalFlowState:
+ m.ResetLocalFlowState()
return nil
- case paymentproviderinstance.FieldRefundEnabled:
- m.ResetRefundEnabled()
+ case pendingauthsession.FieldBrowserSessionKey:
+ m.ResetBrowserSessionKey()
return nil
- case paymentproviderinstance.FieldAllowUserRefund:
- m.ResetAllowUserRefund()
+ case pendingauthsession.FieldCompletionCodeHash:
+ m.ResetCompletionCodeHash()
return nil
- case paymentproviderinstance.FieldCreatedAt:
- m.ResetCreatedAt()
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ResetCompletionCodeExpiresAt()
return nil
- case paymentproviderinstance.FieldUpdatedAt:
- m.ResetUpdatedAt()
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ResetEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ResetPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ResetTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ResetConsumedAt()
return nil
}
- return fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
-func (m *PaymentProviderInstanceMutation) AddedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *PendingAuthSessionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.target_user != nil {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.adoption_decision != nil {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
return edges
}
// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
-func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value {
+func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ if id := m.target_user; id != nil {
+ return []ent.Value{*id}
+ }
+ case pendingauthsession.EdgeAdoptionDecision:
+ if id := m.adoption_decision; id != nil {
+ return []ent.Value{*id}
+ }
+ }
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
-func (m *PaymentProviderInstanceMutation) RemovedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *PendingAuthSessionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
-func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value {
+func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value {
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
-func (m *PaymentProviderInstanceMutation) ClearedEdges() []string {
- edges := make([]string, 0, 0)
+func (m *PendingAuthSessionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedtarget_user {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.clearedadoption_decision {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
-func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool {
+func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ return m.clearedtarget_user
+ case pendingauthsession.EdgeAdoptionDecision:
+ return m.clearedadoption_decision
+ }
return false
}
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
-func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error {
- return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name)
+func (m *PendingAuthSessionMutation) ClearEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ClearTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ClearAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession unique edge %s", name)
}
// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
-func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error {
- return fmt.Errorf("unknown PaymentProviderInstance edge %s", name)
+func (m *PendingAuthSessionMutation) ResetEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ResetTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ResetAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession edge %s", name)
}
// PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph.
@@ -28264,6 +32525,9 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ signup_source *string
+ last_login_at *time.Time
+ last_active_at *time.Time
balance_notify_enabled *bool
balance_notify_threshold_type *string
balance_notify_threshold *float64
@@ -28302,6 +32566,12 @@ type UserMutation struct {
payment_orders map[int64]struct{}
removedpayment_orders map[int64]struct{}
clearedpayment_orders bool
+ auth_identities map[int64]struct{}
+ removedauth_identities map[int64]struct{}
+ clearedauth_identities bool
+ pending_auth_sessions map[int64]struct{}
+ removedpending_auth_sessions map[int64]struct{}
+ clearedpending_auth_sessions bool
done bool
oldValue func(context.Context) (*User, error)
predicates []predicate.User
@@ -28988,6 +33258,140 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetSignupSource sets the "signup_source" field.
+func (m *UserMutation) SetSignupSource(s string) {
+ m.signup_source = &s
+}
+
+// SignupSource returns the value of the "signup_source" field in the mutation.
+func (m *UserMutation) SignupSource() (r string, exists bool) {
+ v := m.signup_source
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSignupSource returns the old "signup_source" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSignupSource is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSignupSource requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSignupSource: %w", err)
+ }
+ return oldValue.SignupSource, nil
+}
+
+// ResetSignupSource resets all changes to the "signup_source" field.
+func (m *UserMutation) ResetSignupSource() {
+ m.signup_source = nil
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (m *UserMutation) SetLastLoginAt(t time.Time) {
+ m.last_login_at = &t
+}
+
+// LastLoginAt returns the value of the "last_login_at" field in the mutation.
+func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) {
+ v := m.last_login_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastLoginAt returns the old "last_login_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastLoginAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err)
+ }
+ return oldValue.LastLoginAt, nil
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (m *UserMutation) ClearLastLoginAt() {
+ m.last_login_at = nil
+ m.clearedFields[user.FieldLastLoginAt] = struct{}{}
+}
+
+// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation.
+func (m *UserMutation) LastLoginAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastLoginAt]
+ return ok
+}
+
+// ResetLastLoginAt resets all changes to the "last_login_at" field.
+func (m *UserMutation) ResetLastLoginAt() {
+ m.last_login_at = nil
+ delete(m.clearedFields, user.FieldLastLoginAt)
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (m *UserMutation) SetLastActiveAt(t time.Time) {
+ m.last_active_at = &t
+}
+
+// LastActiveAt returns the value of the "last_active_at" field in the mutation.
+func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) {
+ v := m.last_active_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastActiveAt returns the old "last_active_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastActiveAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err)
+ }
+ return oldValue.LastActiveAt, nil
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (m *UserMutation) ClearLastActiveAt() {
+ m.last_active_at = nil
+ m.clearedFields[user.FieldLastActiveAt] = struct{}{}
+}
+
+// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation.
+func (m *UserMutation) LastActiveAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastActiveAt]
+ return ok
+}
+
+// ResetLastActiveAt resets all changes to the "last_active_at" field.
+func (m *UserMutation) ResetLastActiveAt() {
+ m.last_active_at = nil
+ delete(m.clearedFields, user.FieldLastActiveAt)
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
m.balance_notify_enabled = &b
@@ -29762,6 +34166,114 @@ func (m *UserMutation) ResetPaymentOrders() {
m.removedpayment_orders = nil
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids.
+func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) {
+ if m.auth_identities == nil {
+ m.auth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.auth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) ClearAuthIdentities() {
+ m.clearedauth_identities = true
+}
+
+// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared.
+func (m *UserMutation) AuthIdentitiesCleared() bool {
+ return m.clearedauth_identities
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) {
+ if m.removedauth_identities == nil {
+ m.removedauth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.auth_identities, ids[i])
+ m.removedauth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) {
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation.
+func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) {
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAuthIdentities resets all changes to the "auth_identities" edge.
+func (m *UserMutation) ResetAuthIdentities() {
+ m.auth_identities = nil
+ m.clearedauth_identities = false
+ m.removedauth_identities = nil
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids.
+func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) {
+ if m.pending_auth_sessions == nil {
+ m.pending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.pending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) ClearPendingAuthSessions() {
+ m.clearedpending_auth_sessions = true
+}
+
+// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared.
+func (m *UserMutation) PendingAuthSessionsCleared() bool {
+ return m.clearedpending_auth_sessions
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) {
+ if m.removedpending_auth_sessions == nil {
+ m.removedpending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.pending_auth_sessions, ids[i])
+ m.removedpending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation.
+func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge.
+func (m *UserMutation) ResetPendingAuthSessions() {
+ m.pending_auth_sessions = nil
+ m.clearedpending_auth_sessions = false
+ m.removedpending_auth_sessions = nil
+}
+
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -29796,7 +34308,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 19)
+ fields := make([]string, 0, 22)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29839,6 +34351,15 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.signup_source != nil {
+ fields = append(fields, user.FieldSignupSource)
+ }
+ if m.last_login_at != nil {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.last_active_at != nil {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
@@ -29890,6 +34411,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldSignupSource:
+ return m.SignupSource()
+ case user.FieldLastLoginAt:
+ return m.LastLoginAt()
+ case user.FieldLastActiveAt:
+ return m.LastActiveAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
case user.FieldBalanceNotifyThresholdType:
@@ -29937,6 +34464,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldSignupSource:
+ return m.OldSignupSource(ctx)
+ case user.FieldLastLoginAt:
+ return m.OldLastLoginAt(ctx)
+ case user.FieldLastActiveAt:
+ return m.OldLastActiveAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
case user.FieldBalanceNotifyThresholdType:
@@ -30054,6 +34587,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldSignupSource:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSignupSource(v)
+ return nil
+ case user.FieldLastLoginAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastLoginAt(v)
+ return nil
+ case user.FieldLastActiveAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastActiveAt(v)
+ return nil
case user.FieldBalanceNotifyEnabled:
v, ok := value.(bool)
if !ok {
@@ -30179,6 +34733,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldLastLoginAt) {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.FieldCleared(user.FieldLastActiveAt) {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
@@ -30205,6 +34765,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldLastLoginAt:
+ m.ClearLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ClearLastActiveAt()
+ return nil
case user.FieldBalanceNotifyThreshold:
m.ClearBalanceNotifyThreshold()
return nil
@@ -30258,6 +34824,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldSignupSource:
+ m.ResetSignupSource()
+ return nil
+ case user.FieldLastLoginAt:
+ m.ResetLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ResetLastActiveAt()
+ return nil
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
@@ -30279,7 +34854,7 @@ func (m *UserMutation) ResetField(name string) error {
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.api_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30310,6 +34885,12 @@ func (m *UserMutation) AddedEdges() []string {
if m.payment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.auth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.pending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30377,13 +34958,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.auth_identities))
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.pending_auth_sessions))
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.removedapi_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30414,6 +35007,12 @@ func (m *UserMutation) RemovedEdges() []string {
if m.removedpayment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.removedauth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.removedpending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30481,13 +35080,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.removedauth_identities))
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions))
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.clearedapi_keys {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30518,6 +35129,12 @@ func (m *UserMutation) ClearedEdges() []string {
if m.clearedpayment_orders {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.clearedauth_identities {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.clearedpending_auth_sessions {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30545,6 +35162,10 @@ func (m *UserMutation) EdgeCleared(name string) bool {
return m.clearedpromo_code_usages
case user.EdgePaymentOrders:
return m.clearedpayment_orders
+ case user.EdgeAuthIdentities:
+ return m.clearedauth_identities
+ case user.EdgePendingAuthSessions:
+ return m.clearedpending_auth_sessions
}
return false
}
@@ -30591,6 +35212,12 @@ func (m *UserMutation) ResetEdge(name string) error {
case user.EdgePaymentOrders:
m.ResetPaymentOrders()
return nil
+ case user.EdgeAuthIdentities:
+ m.ResetAuthIdentities()
+ return nil
+ case user.EdgePendingAuthSessions:
+ m.ResetPendingAuthSessions()
+ return nil
}
return fmt.Errorf("unknown User edge %s", name)
}
diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go
new file mode 100644
index 00000000..e77c065f
--- /dev/null
+++ b/backend/ent/pendingauthsession.go
@@ -0,0 +1,399 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSession is the model entity for the PendingAuthSession schema.
+type PendingAuthSession struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // SessionToken holds the value of the "session_token" field.
+ SessionToken string `json:"session_token,omitempty"`
+ // Intent holds the value of the "intent" field.
+ Intent string `json:"intent,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // TargetUserID holds the value of the "target_user_id" field.
+ TargetUserID *int64 `json:"target_user_id,omitempty"`
+ // RedirectTo holds the value of the "redirect_to" field.
+ RedirectTo string `json:"redirect_to,omitempty"`
+ // ResolvedEmail holds the value of the "resolved_email" field.
+ ResolvedEmail string `json:"resolved_email,omitempty"`
+ // RegistrationPasswordHash holds the value of the "registration_password_hash" field.
+ RegistrationPasswordHash string `json:"registration_password_hash,omitempty"`
+ // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field.
+ UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"`
+ // LocalFlowState holds the value of the "local_flow_state" field.
+ LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"`
+ // BrowserSessionKey holds the value of the "browser_session_key" field.
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
+ // CompletionCodeHash holds the value of the "completion_code_hash" field.
+ CompletionCodeHash string `json:"completion_code_hash,omitempty"`
+ // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field.
+ CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"`
+ // EmailVerifiedAt holds the value of the "email_verified_at" field.
+ EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
+ // PasswordVerifiedAt holds the value of the "password_verified_at" field.
+ PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"`
+ // TotpVerifiedAt holds the value of the "totp_verified_at" field.
+ TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // ConsumedAt holds the value of the "consumed_at" field.
+ ConsumedAt *time.Time `json:"consumed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the PendingAuthSessionQuery when eager-loading is set.
+ Edges PendingAuthSessionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph.
+type PendingAuthSessionEdges struct {
+ // TargetUser holds the value of the target_user edge.
+ TargetUser *User `json:"target_user,omitempty"`
+ // AdoptionDecision holds the value of the adoption_decision edge.
+ AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// TargetUserOrErr returns the TargetUser value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) {
+ if e.TargetUser != nil {
+ return e.TargetUser, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "target_user"}
+}
+
+// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) {
+ if e.AdoptionDecision != nil {
+ return e.AdoptionDecision, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: identityadoptiondecision.Label}
+ }
+ return nil, &NotLoadedError{edge: "adoption_decision"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PendingAuthSession) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState:
+ values[i] = new([]byte)
+ case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID:
+ values[i] = new(sql.NullInt64)
+ case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash:
+ values[i] = new(sql.NullString)
+ case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PendingAuthSession fields.
+func (_m *PendingAuthSession) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case pendingauthsession.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case pendingauthsession.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case pendingauthsession.FieldSessionToken:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field session_token", values[i])
+ } else if value.Valid {
+ _m.SessionToken = value.String
+ }
+ case pendingauthsession.FieldIntent:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field intent", values[i])
+ } else if value.Valid {
+ _m.Intent = value.String
+ }
+ case pendingauthsession.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case pendingauthsession.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case pendingauthsession.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case pendingauthsession.FieldTargetUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field target_user_id", values[i])
+ } else if value.Valid {
+ _m.TargetUserID = new(int64)
+ *_m.TargetUserID = value.Int64
+ }
+ case pendingauthsession.FieldRedirectTo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field redirect_to", values[i])
+ } else if value.Valid {
+ _m.RedirectTo = value.String
+ }
+ case pendingauthsession.FieldResolvedEmail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field resolved_email", values[i])
+ } else if value.Valid {
+ _m.ResolvedEmail = value.String
+ }
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i])
+ } else if value.Valid {
+ _m.RegistrationPasswordHash = value.String
+ }
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil {
+ return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err)
+ }
+ }
+ case pendingauthsession.FieldLocalFlowState:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field local_flow_state", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil {
+ return fmt.Errorf("unmarshal field local_flow_state: %w", err)
+ }
+ }
+ case pendingauthsession.FieldBrowserSessionKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field browser_session_key", values[i])
+ } else if value.Valid {
+ _m.BrowserSessionKey = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeHash = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeExpiresAt = new(time.Time)
+ *_m.CompletionCodeExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldEmailVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field email_verified_at", values[i])
+ } else if value.Valid {
+ _m.EmailVerifiedAt = new(time.Time)
+ *_m.EmailVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field password_verified_at", values[i])
+ } else if value.Valid {
+ _m.PasswordVerifiedAt = new(time.Time)
+ *_m.PasswordVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldTotpVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i])
+ } else if value.Valid {
+ _m.TotpVerifiedAt = new(time.Time)
+ *_m.TotpVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldConsumedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field consumed_at", values[i])
+ } else if value.Valid {
+ _m.ConsumedAt = new(time.Time)
+ *_m.ConsumedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession.
+// This includes values selected through modifiers, order, etc.
+func (_m *PendingAuthSession) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryTargetUser() *UserQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m)
+}
+
+// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m)
+}
+
+// Update returns a builder for updating this PendingAuthSession.
+// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne {
+ return NewPendingAuthSessionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PendingAuthSession) Unwrap() *PendingAuthSession {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PendingAuthSession is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PendingAuthSession) String() string {
+ var builder strings.Builder
+ builder.WriteString("PendingAuthSession(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("session_token=")
+ builder.WriteString(_m.SessionToken)
+ builder.WriteString(", ")
+ builder.WriteString("intent=")
+ builder.WriteString(_m.Intent)
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.TargetUserID; v != nil {
+ builder.WriteString("target_user_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("redirect_to=")
+ builder.WriteString(_m.RedirectTo)
+ builder.WriteString(", ")
+ builder.WriteString("resolved_email=")
+ builder.WriteString(_m.ResolvedEmail)
+ builder.WriteString(", ")
+ builder.WriteString("registration_password_hash=")
+ builder.WriteString(_m.RegistrationPasswordHash)
+ builder.WriteString(", ")
+ builder.WriteString("upstream_identity_claims=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims))
+ builder.WriteString(", ")
+ builder.WriteString("local_flow_state=")
+ builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState))
+ builder.WriteString(", ")
+ builder.WriteString("browser_session_key=")
+ builder.WriteString(_m.BrowserSessionKey)
+ builder.WriteString(", ")
+ builder.WriteString("completion_code_hash=")
+ builder.WriteString(_m.CompletionCodeHash)
+ builder.WriteString(", ")
+ if v := _m.CompletionCodeExpiresAt; v != nil {
+ builder.WriteString("completion_code_expires_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.EmailVerifiedAt; v != nil {
+ builder.WriteString("email_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.PasswordVerifiedAt; v != nil {
+ builder.WriteString("password_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.TotpVerifiedAt; v != nil {
+ builder.WriteString("totp_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ if v := _m.ConsumedAt; v != nil {
+ builder.WriteString("consumed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PendingAuthSessions is a parsable slice of PendingAuthSession.
+type PendingAuthSessions []*PendingAuthSession
diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go
new file mode 100644
index 00000000..8a3ac9bf
--- /dev/null
+++ b/backend/ent/pendingauthsession/pendingauthsession.go
@@ -0,0 +1,279 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the pendingauthsession type in the database.
+ Label = "pending_auth_session"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldSessionToken holds the string denoting the session_token field in the database.
+ FieldSessionToken = "session_token"
+ // FieldIntent holds the string denoting the intent field in the database.
+ FieldIntent = "intent"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldTargetUserID holds the string denoting the target_user_id field in the database.
+ FieldTargetUserID = "target_user_id"
+ // FieldRedirectTo holds the string denoting the redirect_to field in the database.
+ FieldRedirectTo = "redirect_to"
+ // FieldResolvedEmail holds the string denoting the resolved_email field in the database.
+ FieldResolvedEmail = "resolved_email"
+ // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database.
+ FieldRegistrationPasswordHash = "registration_password_hash"
+ // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database.
+ FieldUpstreamIdentityClaims = "upstream_identity_claims"
+ // FieldLocalFlowState holds the string denoting the local_flow_state field in the database.
+ FieldLocalFlowState = "local_flow_state"
+ // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database.
+ FieldBrowserSessionKey = "browser_session_key"
+ // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database.
+ FieldCompletionCodeHash = "completion_code_hash"
+ // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database.
+ FieldCompletionCodeExpiresAt = "completion_code_expires_at"
+ // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database.
+ FieldEmailVerifiedAt = "email_verified_at"
+ // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database.
+ FieldPasswordVerifiedAt = "password_verified_at"
+ // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database.
+ FieldTotpVerifiedAt = "totp_verified_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldConsumedAt holds the string denoting the consumed_at field in the database.
+ FieldConsumedAt = "consumed_at"
+ // EdgeTargetUser holds the string denoting the target_user edge name in mutations.
+ EdgeTargetUser = "target_user"
+ // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations.
+ EdgeAdoptionDecision = "adoption_decision"
+ // Table holds the table name of the pendingauthsession in the database.
+ Table = "pending_auth_sessions"
+ // TargetUserTable is the table that holds the target_user relation/edge.
+ TargetUserTable = "pending_auth_sessions"
+ // TargetUserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ TargetUserInverseTable = "users"
+ // TargetUserColumn is the table column denoting the target_user relation/edge.
+ TargetUserColumn = "target_user_id"
+ // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge.
+ AdoptionDecisionTable = "identity_adoption_decisions"
+ // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge.
+ AdoptionDecisionColumn = "pending_auth_session_id"
+)
+
+// Columns holds all SQL columns for pendingauthsession fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldSessionToken,
+ FieldIntent,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldTargetUserID,
+ FieldRedirectTo,
+ FieldResolvedEmail,
+ FieldRegistrationPasswordHash,
+ FieldUpstreamIdentityClaims,
+ FieldLocalFlowState,
+ FieldBrowserSessionKey,
+ FieldCompletionCodeHash,
+ FieldCompletionCodeExpiresAt,
+ FieldEmailVerifiedAt,
+ FieldPasswordVerifiedAt,
+ FieldTotpVerifiedAt,
+ FieldExpiresAt,
+ FieldConsumedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ SessionTokenValidator func(string) error
+ // IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ IntentValidator func(string) error
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultRedirectTo holds the default value on creation for the "redirect_to" field.
+ DefaultRedirectTo string
+ // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field.
+ DefaultResolvedEmail string
+ // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field.
+ DefaultRegistrationPasswordHash string
+ // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field.
+ DefaultUpstreamIdentityClaims func() map[string]interface{}
+ // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field.
+ DefaultLocalFlowState func() map[string]interface{}
+ // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field.
+ DefaultBrowserSessionKey string
+ // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field.
+ DefaultCompletionCodeHash string
+)
+
+// OrderOption defines the ordering options for the PendingAuthSession queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// BySessionToken orders the results by the session_token field.
+func BySessionToken(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSessionToken, opts...).ToFunc()
+}
+
+// ByIntent orders the results by the intent field.
+func ByIntent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntent, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByTargetUserID orders the results by the target_user_id field.
+func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTargetUserID, opts...).ToFunc()
+}
+
+// ByRedirectTo orders the results by the redirect_to field.
+func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRedirectTo, opts...).ToFunc()
+}
+
+// ByResolvedEmail orders the results by the resolved_email field.
+func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc()
+}
+
+// ByRegistrationPasswordHash orders the results by the registration_password_hash field.
+func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc()
+}
+
+// ByBrowserSessionKey orders the results by the browser_session_key field.
+func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc()
+}
+
+// ByCompletionCodeHash orders the results by the completion_code_hash field.
+func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc()
+}
+
+// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field.
+func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc()
+}
+
+// ByEmailVerifiedAt orders the results by the email_verified_at field.
+func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc()
+}
+
+// ByPasswordVerifiedAt orders the results by the password_verified_at field.
+func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc()
+}
+
+// ByTotpVerifiedAt orders the results by the totp_verified_at field.
+func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByConsumedAt orders the results by the consumed_at field.
+func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldConsumedAt, opts...).ToFunc()
+}
+
+// ByTargetUserField orders the results by target_user field.
+func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByAdoptionDecisionField orders the results by adoption_decision field.
+func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newTargetUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(TargetUserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+}
+func newAdoptionDecisionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+}
diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go
new file mode 100644
index 00000000..cb316f47
--- /dev/null
+++ b/backend/ent/pendingauthsession/where.go
@@ -0,0 +1,1262 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ.
+func SessionToken(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ.
+func Intent(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ.
+func TargetUserID(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ.
+func RedirectTo(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ.
+func ResolvedEmail(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ.
+func RegistrationPasswordHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ.
+func BrowserSessionKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ.
+func CompletionCodeHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ.
+func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ.
+func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ.
+func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ.
+func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ.
+func ConsumedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// SessionTokenEQ applies the EQ predicate on the "session_token" field.
+func SessionTokenEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// SessionTokenNEQ applies the NEQ predicate on the "session_token" field.
+func SessionTokenNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v))
+}
+
+// SessionTokenIn applies the In predicate on the "session_token" field.
+func SessionTokenIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenNotIn applies the NotIn predicate on the "session_token" field.
+func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenGT applies the GT predicate on the "session_token" field.
+func SessionTokenGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v))
+}
+
+// SessionTokenGTE applies the GTE predicate on the "session_token" field.
+func SessionTokenGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v))
+}
+
+// SessionTokenLT applies the LT predicate on the "session_token" field.
+func SessionTokenLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v))
+}
+
+// SessionTokenLTE applies the LTE predicate on the "session_token" field.
+func SessionTokenLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v))
+}
+
+// SessionTokenContains applies the Contains predicate on the "session_token" field.
+func SessionTokenContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v))
+}
+
+// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field.
+func SessionTokenHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v))
+}
+
+// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field.
+func SessionTokenHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v))
+}
+
+// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field.
+func SessionTokenEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v))
+}
+
+// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field.
+func SessionTokenContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v))
+}
+
+// IntentEQ applies the EQ predicate on the "intent" field.
+func IntentEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// IntentNEQ applies the NEQ predicate on the "intent" field.
+func IntentNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v))
+}
+
+// IntentIn applies the In predicate on the "intent" field.
+func IntentIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...))
+}
+
+// IntentNotIn applies the NotIn predicate on the "intent" field.
+func IntentNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...))
+}
+
+// IntentGT applies the GT predicate on the "intent" field.
+func IntentGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v))
+}
+
+// IntentGTE applies the GTE predicate on the "intent" field.
+func IntentGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v))
+}
+
+// IntentLT applies the LT predicate on the "intent" field.
+func IntentLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v))
+}
+
+// IntentLTE applies the LTE predicate on the "intent" field.
+func IntentLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v))
+}
+
+// IntentContains applies the Contains predicate on the "intent" field.
+func IntentContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v))
+}
+
+// IntentHasPrefix applies the HasPrefix predicate on the "intent" field.
+func IntentHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v))
+}
+
+// IntentHasSuffix applies the HasSuffix predicate on the "intent" field.
+func IntentHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v))
+}
+
+// IntentEqualFold applies the EqualFold predicate on the "intent" field.
+func IntentEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v))
+}
+
+// IntentContainsFold applies the ContainsFold predicate on the "intent" field.
+func IntentContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field.
+func TargetUserIDEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field.
+func TargetUserIDNEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDIn applies the In predicate on the "target_user_id" field.
+func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field.
+func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field.
+func TargetUserIDIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID))
+}
+
+// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field.
+func TargetUserIDNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID))
+}
+
+// RedirectToEQ applies the EQ predicate on the "redirect_to" field.
+func RedirectToEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field.
+func RedirectToNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v))
+}
+
+// RedirectToIn applies the In predicate on the "redirect_to" field.
+func RedirectToIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field.
+func RedirectToNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToGT applies the GT predicate on the "redirect_to" field.
+func RedirectToGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v))
+}
+
+// RedirectToGTE applies the GTE predicate on the "redirect_to" field.
+func RedirectToGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v))
+}
+
+// RedirectToLT applies the LT predicate on the "redirect_to" field.
+func RedirectToLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v))
+}
+
+// RedirectToLTE applies the LTE predicate on the "redirect_to" field.
+func RedirectToLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v))
+}
+
+// RedirectToContains applies the Contains predicate on the "redirect_to" field.
+func RedirectToContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v))
+}
+
+// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field.
+func RedirectToHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v))
+}
+
+// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field.
+func RedirectToHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v))
+}
+
+// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field.
+func RedirectToEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v))
+}
+
+// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field.
+func RedirectToContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v))
+}
+
+// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field.
+func ResolvedEmailEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field.
+func ResolvedEmailNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailIn applies the In predicate on the "resolved_email" field.
+func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field.
+func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailGT applies the GT predicate on the "resolved_email" field.
+func ResolvedEmailGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field.
+func ResolvedEmailGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLT applies the LT predicate on the "resolved_email" field.
+func ResolvedEmailLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field.
+func ResolvedEmailLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field.
+func ResolvedEmailContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field.
+func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field.
+func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field.
+func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field.
+func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field.
+func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field.
+func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field.
+func BrowserSessionKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field.
+func BrowserSessionKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field.
+func BrowserSessionKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field.
+func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field.
+func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field.
+func CompletionCodeHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field.
+func CompletionCodeHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field.
+func CompletionCodeHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt))
+}
+
+// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt))
+}
+
+// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field.
+func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field.
+func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field.
+func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt))
+}
+
+// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt))
+}
+
+// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt))
+}
+
+// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt))
+}
+
+// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt))
+}
+
+// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field.
+func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field.
+func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtIn applies the In predicate on the "consumed_at" field.
+func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field.
+func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtGT applies the GT predicate on the "consumed_at" field.
+func ConsumedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v))
+}
+
+// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field.
+func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtLT applies the LT predicate on the "consumed_at" field.
+func ConsumedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v))
+}
+
+// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field.
+func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field.
+func ConsumedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt))
+}
+
+// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field.
+func ConsumedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt))
+}
+
+// HasTargetUser applies the HasEdge predicate on the "target_user" edge.
+func HasTargetUser() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates).
+func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newTargetUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge.
+func HasAdoptionDecision() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates).
+func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newAdoptionDecisionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.NotPredicates(p))
+}
diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go
new file mode 100644
index 00000000..60276daa
--- /dev/null
+++ b/backend/ent/pendingauthsession_create.go
@@ -0,0 +1,1815 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity.
+type PendingAuthSessionCreate struct {
+ config
+ mutation *PendingAuthSessionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetSessionToken(v)
+ return _c
+}
+
+// SetIntent sets the "intent" field.
+func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetIntent(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate {
+ _c.mutation.SetTargetUserID(v)
+ return _c
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTargetUserID(*v)
+ }
+ return _c
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRedirectTo(v)
+ return _c
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRedirectTo(*v)
+ }
+ return _c
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetResolvedEmail(v)
+ return _c
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetResolvedEmail(*v)
+ }
+ return _c
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRegistrationPasswordHash(v)
+ return _c
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRegistrationPasswordHash(*v)
+ }
+ return _c
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ return _c
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetLocalFlowState(v)
+ return _c
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetBrowserSessionKey(v)
+ return _c
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetBrowserSessionKey(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeHash(v)
+ return _c
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeHash(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeExpiresAt(v)
+ return _c
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeExpiresAt(*v)
+ }
+ return _c
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetEmailVerifiedAt(v)
+ return _c
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetEmailVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetPasswordVerifiedAt(v)
+ return _c
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetPasswordVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetTotpVerifiedAt(v)
+ return _c
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTotpVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetConsumedAt(v)
+ return _c
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetConsumedAt(*v)
+ }
+ return _c
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate {
+ return _c.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate {
+ _c.mutation.SetAdoptionDecisionID(id)
+ return _c
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate {
+ if id != nil {
+ _c = _c.SetAdoptionDecisionID(*id)
+ }
+ return _c
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate {
+ return _c.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation {
+ return _c.mutation
+}
+
+// Save creates the PendingAuthSession in the database.
+func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PendingAuthSessionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := pendingauthsession.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ v := pendingauthsession.DefaultRedirectTo
+ _c.mutation.SetRedirectTo(v)
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ v := pendingauthsession.DefaultResolvedEmail
+ _c.mutation.SetResolvedEmail(v)
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ v := pendingauthsession.DefaultRegistrationPasswordHash
+ _c.mutation.SetRegistrationPasswordHash(v)
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ v := pendingauthsession.DefaultUpstreamIdentityClaims()
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ v := pendingauthsession.DefaultLocalFlowState()
+ _c.mutation.SetLocalFlowState(v)
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ v := pendingauthsession.DefaultBrowserSessionKey
+ _c.mutation.SetBrowserSessionKey(v)
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ v := pendingauthsession.DefaultCompletionCodeHash
+ _c.mutation.SetCompletionCodeHash(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PendingAuthSessionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)}
+ }
+ if _, ok := _c.mutation.SessionToken(); !ok {
+ return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)}
+ }
+ if v, ok := _c.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Intent(); !ok {
+ return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)}
+ }
+ if v, ok := _c.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)}
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)}
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)}
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)}
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)}
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)}
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)}
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)}
+ }
+ return nil
+}
+
+func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PendingAuthSession{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ _node.SessionToken = value
+ }
+ if value, ok := _c.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ _node.Intent = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ _node.RedirectTo = value
+ }
+ if value, ok := _c.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ _node.ResolvedEmail = value
+ }
+ if value, ok := _c.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ _node.RegistrationPasswordHash = value
+ }
+ if value, ok := _c.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ _node.UpstreamIdentityClaims = value
+ }
+ if value, ok := _c.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ _node.LocalFlowState = value
+ }
+ if value, ok := _c.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ _node.BrowserSessionKey = value
+ }
+ if value, ok := _c.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ _node.CompletionCodeHash = value
+ }
+ if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ _node.CompletionCodeExpiresAt = &value
+ }
+ if value, ok := _c.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ _node.EmailVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ _node.PasswordVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ _node.TotpVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := _c.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ _node.ConsumedAt = &value
+ }
+ if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TargetUserID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+//			// that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+//			u.SetCreatedAt(v + v)
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PendingAuthSessionUpsertOne is the builder for "upsert"-ing
+ // one PendingAuthSession node.
+ PendingAuthSessionUpsertOne struct {
+ create *PendingAuthSessionCreate
+ }
+
+ // PendingAuthSessionUpsert is the "OnConflict" setter.
+ PendingAuthSessionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpdatedAt)
+ return u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldSessionToken, v)
+ return u
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldSessionToken)
+ return u
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldIntent, v)
+ return u
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldIntent)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderSubject)
+ return u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTargetUserID, v)
+ return u
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRedirectTo, v)
+ return u
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRedirectTo)
+ return u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldResolvedEmail, v)
+ return u
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldResolvedEmail)
+ return u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRegistrationPasswordHash, v)
+ return u
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash)
+ return u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v)
+ return u
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims)
+ return u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldLocalFlowState, v)
+ return u
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldLocalFlowState)
+ return u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldBrowserSessionKey, v)
+ return u
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldBrowserSessionKey)
+ return u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeHash, v)
+ return u
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeHash)
+ return u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v)
+ return u
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldEmailVerifiedAt, v)
+ return u
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldPasswordVerifiedAt, v)
+ return u
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTotpVerifiedAt, v)
+ return u
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldExpiresAt)
+ return u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldConsumedAt, v)
+ return u
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk.
+type PendingAuthSessionCreateBulk struct {
+ config
+ err error
+ builders []*PendingAuthSessionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PendingAuthSession entities in the database.
+func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PendingAuthSession, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PendingAuthSessionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing
+// a bulk of PendingAuthSession nodes.
+type PendingAuthSessionUpsertBulk struct {
+ create *PendingAuthSessionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go
new file mode 100644
index 00000000..ee4fe605
--- /dev/null
+++ b/backend/ent/pendingauthsession_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity.
+type PendingAuthSessionDelete struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list of predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity.
+type PendingAuthSessionDeleteOne struct {
+ _d *PendingAuthSessionDelete
+}
+
+// Where appends a list of predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go
new file mode 100644
index 00000000..78e29cd2
--- /dev/null
+++ b/backend/ent/pendingauthsession_query.go
@@ -0,0 +1,717 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities.
+type PendingAuthSessionQuery struct {
+ config
+ ctx *QueryContext
+ order []pendingauthsession.OrderOption
+ inters []Interceptor
+ predicates []predicate.PendingAuthSession
+ withTargetUser *UserQuery
+ withAdoptionDecision *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PendingAuthSessionQuery builder.
+func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryTargetUser chains the current query on the "target_user" edge.
+func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision chains the current query on the "adoption_decision" edge.
+func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first PendingAuthSession entity from the query.
+// Returns a *NotFoundError when no PendingAuthSession was found.
+func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{pendingauthsession.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PendingAuthSession ID from the query.
+// Returns a *NotFoundError when no PendingAuthSession ID was found.
+func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{pendingauthsession.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PendingAuthSession entity is found.
+// Returns a *NotFoundError when no PendingAuthSession entities are found.
+func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil, &NotSingularError{pendingauthsession.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PendingAuthSession ID in the query.
+// Returns a *NotSingularError when more than one PendingAuthSession ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{pendingauthsession.Label}
+ default:
+ err = &NotSingularError{pendingauthsession.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PendingAuthSessions.
+func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]()
+ return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PendingAuthSession IDs.
+func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PendingAuthSessionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]pendingauthsession.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PendingAuthSession{}, _q.predicates...),
+ withTargetUser: _q.withTargetUser.Clone(),
+ withAdoptionDecision: _q.withAdoptionDecision.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithTargetUser tells the query-builder to eager-load the nodes that are connected to
+// the "target_user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withTargetUser = query
+ return _q
+}
+
+// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecision = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// GroupBy(pendingauthsession.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PendingAuthSessionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = pendingauthsession.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// Select(pendingauthsession.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q}
+ sbuild.label = pendingauthsession.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations.
+func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+// prepareQuery validates the builder before execution: it runs every
+// registered Traverser interceptor, checks that each selected field is a
+// valid pendingauthsession column, and — when this query is a traversal —
+// resolves the stored path into the intermediate selector (_q.sql).
+func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error {
+	for _, inter := range _q.inters {
+		if inter == nil {
+			return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+		}
+		if trv, ok := inter.(Traverser); ok {
+			if err := trv.Traverse(ctx, _q); err != nil {
+				return err
+			}
+		}
+	}
+	for _, f := range _q.ctx.Fields {
+		if !pendingauthsession.ValidColumn(f) {
+			return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+		}
+	}
+	// A traversal query materializes its selector lazily via path.
+	if _q.path != nil {
+		prev, err := _q.path(ctx)
+		if err != nil {
+			return err
+		}
+		_q.sql = prev
+	}
+	return nil
+}
+
+// sqlAll executes the query and scans all matching rows into
+// *PendingAuthSession nodes, then eager-loads the "target_user" and
+// "adoption_decision" edges when requested via WithTargetUser /
+// WithAdoptionDecision. The optional hooks may mutate the query spec
+// (e.g. for pagination internals) before execution.
+func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) {
+	var (
+		nodes       = []*PendingAuthSession{}
+		_spec       = _q.querySpec()
+		loadedTypes = [2]bool{
+			_q.withTargetUser != nil,
+			_q.withAdoptionDecision != nil,
+		}
+	)
+	_spec.ScanValues = func(columns []string) ([]any, error) {
+		return (*PendingAuthSession).scanValues(nil, columns)
+	}
+	_spec.Assign = func(columns []string, values []any) error {
+		node := &PendingAuthSession{config: _q.config}
+		nodes = append(nodes, node)
+		// Record which edges were eager-loaded so Edges accessors can
+		// distinguish "not loaded" from "loaded but empty".
+		node.Edges.loadedTypes = loadedTypes
+		return node.assignValues(columns, values)
+	}
+	if len(_q.modifiers) > 0 {
+		_spec.Modifiers = _q.modifiers
+	}
+	for i := range hooks {
+		hooks[i](ctx, _spec)
+	}
+	if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+		return nil, err
+	}
+	if len(nodes) == 0 {
+		return nodes, nil
+	}
+	if query := _q.withTargetUser; query != nil {
+		if err := _q.loadTargetUser(ctx, query, nodes, nil,
+			func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil {
+			return nil, err
+		}
+	}
+	if query := _q.withAdoptionDecision; query != nil {
+		if err := _q.loadAdoptionDecision(ctx, query, nodes, nil,
+			func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil {
+			return nil, err
+		}
+	}
+	return nodes, nil
+}
+
+// loadTargetUser batch-loads the M2O "target_user" edge for the given nodes:
+// it collects the distinct non-nil TargetUserID foreign keys, fetches the
+// matching users in a single query, and assigns each user back to every
+// session that references it. Sessions with a nil TargetUserID are skipped.
+func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error {
+	ids := make([]int64, 0, len(nodes))
+	nodeids := make(map[int64][]*PendingAuthSession)
+	for i := range nodes {
+		if nodes[i].TargetUserID == nil {
+			continue
+		}
+		fk := *nodes[i].TargetUserID
+		if _, ok := nodeids[fk]; !ok {
+			ids = append(ids, fk)
+		}
+		nodeids[fk] = append(nodeids[fk], nodes[i])
+	}
+	if len(ids) == 0 {
+		return nil
+	}
+	query.Where(user.IDIn(ids...))
+	neighbors, err := query.All(ctx)
+	if err != nil {
+		return err
+	}
+	for _, n := range neighbors {
+		nodes, ok := nodeids[n.ID]
+		if !ok {
+			return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID)
+		}
+		for i := range nodes {
+			assign(nodes[i], n)
+		}
+	}
+	return nil
+}
+// loadAdoptionDecision batch-loads the O2O "adoption_decision" edge: it
+// queries decisions whose pending_auth_session_id (the AdoptionDecisionColumn
+// back-reference) matches any of the given node IDs, force-selecting that FK
+// column when a custom field set is in play, and assigns each decision to
+// its owning session.
+func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error {
+	fks := make([]driver.Value, 0, len(nodes))
+	nodeids := make(map[int64]*PendingAuthSession)
+	for i := range nodes {
+		fks = append(fks, nodes[i].ID)
+		nodeids[nodes[i].ID] = nodes[i]
+	}
+	// Ensure the FK column is fetched even under a narrowed field selection,
+	// otherwise n.PendingAuthSessionID below would be the zero value.
+	if len(query.ctx.Fields) > 0 {
+		query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+	}
+	query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+		s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...))
+	}))
+	neighbors, err := query.All(ctx)
+	if err != nil {
+		return err
+	}
+	for _, n := range neighbors {
+		fk := n.PendingAuthSessionID
+		node, ok := nodeids[fk]
+		if !ok {
+			return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID)
+		}
+		assign(node, n)
+	}
+	return nil
+}
+
+// sqlCount builds a COUNT query from the current spec, carrying over any
+// modifiers and selected columns; when explicit fields are selected, the
+// count honors the builder's Unique (DISTINCT) setting.
+func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
+	_spec := _q.querySpec()
+	if len(_q.modifiers) > 0 {
+		_spec.Modifiers = _q.modifiers
+	}
+	_spec.Node.Columns = _q.ctx.Fields
+	if len(_q.ctx.Fields) > 0 {
+		_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+	}
+	return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+// querySpec translates the builder state (field selection, predicates,
+// ordering, limit/offset, uniqueness) into a sqlgraph.QuerySpec. The ID
+// column is always included in the selection, and the target_user_id FK is
+// force-added when the "target_user" edge is eager-loaded so the edge can
+// be resolved. Traversal queries (path != nil) default to unique results.
+func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
+	_spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+	_spec.From = _q.sql
+	if unique := _q.ctx.Unique; unique != nil {
+		_spec.Unique = *unique
+	} else if _q.path != nil {
+		_spec.Unique = true
+	}
+	if fields := _q.ctx.Fields; len(fields) > 0 {
+		_spec.Node.Columns = make([]string, 0, len(fields))
+		_spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+		for i := range fields {
+			if fields[i] != pendingauthsession.FieldID {
+				_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+			}
+		}
+		if _q.withTargetUser != nil {
+			_spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID)
+		}
+	}
+	if ps := _q.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	if limit := _q.ctx.Limit; limit != nil {
+		_spec.Limit = *limit
+	}
+	if offset := _q.ctx.Offset; offset != nil {
+		_spec.Offset = *offset
+	}
+	if ps := _q.order; len(ps) > 0 {
+		_spec.Order = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	return _spec
+}
+
+// sqlQuery builds the raw *sql.Selector for this query — either from
+// scratch over the pendingauthsession table or on top of an existing
+// traversal selector (_q.sql) — then applies DISTINCT, modifiers,
+// predicates, ordering, and pagination in that order.
+func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+	builder := sql.Dialect(_q.driver.Dialect())
+	t1 := builder.Table(pendingauthsession.Table)
+	columns := _q.ctx.Fields
+	if len(columns) == 0 {
+		columns = pendingauthsession.Columns
+	}
+	selector := builder.Select(t1.Columns(columns...)...).From(t1)
+	if _q.sql != nil {
+		// Continue from the traversal's selector instead of a fresh table scan.
+		selector = _q.sql
+		selector.Select(selector.Columns(columns...)...)
+	}
+	if _q.ctx.Unique != nil && *_q.ctx.Unique {
+		selector.Distinct()
+	}
+	for _, m := range _q.modifiers {
+		m(selector)
+	}
+	for _, p := range _q.predicates {
+		p(selector)
+	}
+	for _, p := range _q.order {
+		p(selector)
+	}
+	if offset := _q.ctx.Offset; offset != nil {
+		// limit is mandatory for offset clause. We start
+		// with default value, and override it below if needed.
+		selector.Offset(*offset).Limit(math.MaxInt32)
+	}
+	if limit := _q.ctx.Limit; limit != nil {
+		selector.Limit(*limit)
+	}
+	return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities.
+type PendingAuthSessionGroupBy struct {
+ selector
+ build *PendingAuthSessionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+// sqlScan executes the group-by query: it appends the aggregation functions
+// to the grouped columns (when no explicit selection exists), groups by the
+// requested fields, runs the query, and scans the rows into v.
+func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+	selector := root.sqlQuery(ctx).Select()
+	aggregation := make([]string, 0, len(_g.fns))
+	for _, fn := range _g.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	if len(selector.SelectedColumns()) == 0 {
+		columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+		for _, f := range *_g.flds {
+			columns = append(columns, selector.C(f))
+		}
+		columns = append(columns, aggregation...)
+		selector.Select(columns...)
+	}
+	selector.GroupBy(selector.Columns(*_g.flds...)...)
+	if err := selector.Err(); err != nil {
+		return err
+	}
+	rows := &sql.Rows{}
+	query, args := selector.Query()
+	if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+		return err
+	}
+	defer rows.Close()
+	return sql.ScanSlice(rows, v)
+}
+
+// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities.
+type PendingAuthSessionSelect struct {
+ *PendingAuthSessionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v)
+}
+
+// sqlScan executes the select query, using the aggregation functions as the
+// sole selection when no fields were chosen, or appending them to the
+// selected fields otherwise, and scans the resulting rows into v.
+func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+	selector := root.sqlQuery(ctx)
+	aggregation := make([]string, 0, len(_s.fns))
+	for _, fn := range _s.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	switch n := len(*_s.selector.flds); {
+	case n == 0 && len(aggregation) > 0:
+		selector.Select(aggregation...)
+	case n != 0 && len(aggregation) > 0:
+		selector.AppendSelect(aggregation...)
+	}
+	rows := &sql.Rows{}
+	query, args := selector.Query()
+	if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+		return err
+	}
+	defer rows.Close()
+	return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go
new file mode 100644
index 00000000..00066f69
--- /dev/null
+++ b/backend/ent/pendingauthsession_update.go
@@ -0,0 +1,1178 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities.
+type PendingAuthSessionUpdate struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// NOTE(review): ent-generated builder accessors — regenerate with `go generate ./ent`
+// rather than hand-editing; each setter records the value on the mutation only,
+// deferring all SQL work to Save/sqlSave.
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetCompletionCodeHash(*v)
+	}
+	return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+	_u.mutation.SetCompletionCodeExpiresAt(v)
+	return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetCompletionCodeExpiresAt(*v)
+	}
+	return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate {
+	_u.mutation.ClearCompletionCodeExpiresAt()
+	return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+	_u.mutation.SetEmailVerifiedAt(v)
+	return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetEmailVerifiedAt(*v)
+	}
+	return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate {
+	_u.mutation.ClearEmailVerifiedAt()
+	return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+	_u.mutation.SetPasswordVerifiedAt(v)
+	return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetPasswordVerifiedAt(*v)
+	}
+	return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate {
+	_u.mutation.ClearPasswordVerifiedAt()
+	return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+	_u.mutation.SetTotpVerifiedAt(v)
+	return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetTotpVerifiedAt(*v)
+	}
+	return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate {
+	_u.mutation.ClearTotpVerifiedAt()
+	return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+	_u.mutation.SetExpiresAt(v)
+	return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetExpiresAt(*v)
+	}
+	return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate {
+	_u.mutation.SetConsumedAt(v)
+	return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate {
+	if v != nil {
+		_u.SetConsumedAt(*v)
+	}
+	return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate {
+	_u.mutation.ClearConsumedAt()
+	return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate {
+	return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate {
+	_u.mutation.SetAdoptionDecisionID(id)
+	return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate {
+	if id != nil {
+		_u = _u.SetAdoptionDecisionID(*id)
+	}
+	return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate {
+	return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation {
+	return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate {
+	_u.mutation.ClearTargetUser()
+	return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate {
+	_u.mutation.ClearAdoptionDecision()
+	return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) {
+	// Populate defaults (e.g. updated_at) before running hooks and the SQL update.
+	_u.defaults()
+	return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int {
+	affected, err := _u.Save(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return affected
+}
+
+// Exec executes the query.
+func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error {
+	_, err := _u.Save(ctx)
+	return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) {
+	if err := _u.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdate) defaults() {
+	// updated_at is auto-touched on every update unless the caller set it explicitly.
+	if _, ok := _u.mutation.UpdatedAt(); !ok {
+		v := pendingauthsession.UpdateDefaultUpdatedAt()
+		_u.mutation.SetUpdatedAt(v)
+	}
+}
+
+// check runs all checks and user-defined validators on the builder.
+// Only fields actually present on the mutation are validated, so partial
+// updates skip validators for untouched columns.
+func (_u *PendingAuthSessionUpdate) check() error {
+	if v, ok := _u.mutation.SessionToken(); ok {
+		if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+			return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.Intent(); ok {
+		if err := pendingauthsession.IntentValidator(v); err != nil {
+			return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.ProviderType(); ok {
+		if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+			return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.ProviderKey(); ok {
+		if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+			return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.ProviderSubject(); ok {
+		if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+			return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+		}
+	}
+	return nil
+}
+
+// sqlSave validates the mutation, translates every staged field/edge change
+// into a sqlgraph update spec, executes it, and returns the number of rows
+// affected. sqlgraph errors are mapped to ent's NotFoundError/ConstraintError.
+func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+	if err := _u.check(); err != nil {
+		return _node, err
+	}
+	_spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+	// Apply any Where(...) predicates accumulated on the builder.
+	if ps := _u.mutation.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	if value, ok := _u.mutation.UpdatedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+	}
+	if value, ok := _u.mutation.SessionToken(); ok {
+		_spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.Intent(); ok {
+		_spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ProviderType(); ok {
+		_spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ProviderKey(); ok {
+		_spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ProviderSubject(); ok {
+		_spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.RedirectTo(); ok {
+		_spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ResolvedEmail(); ok {
+		_spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+		_spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+		_spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+	}
+	if value, ok := _u.mutation.LocalFlowState(); ok {
+		_spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+	}
+	if value, ok := _u.mutation.BrowserSessionKey(); ok {
+		_spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.CompletionCodeHash(); ok {
+		_spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+		_spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+	}
+	if _u.mutation.CompletionCodeExpiresAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+	}
+	if _u.mutation.EmailVerifiedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+	}
+	if _u.mutation.PasswordVerifiedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+	}
+	if _u.mutation.TotpVerifiedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.ExpiresAt(); ok {
+		_spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+	}
+	if value, ok := _u.mutation.ConsumedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+	}
+	if _u.mutation.ConsumedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+	}
+	// target_user edge (M2O, inverse): clear first, then add replacement IDs.
+	if _u.mutation.TargetUserCleared() {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.M2O,
+			Inverse: true,
+			Table:   pendingauthsession.TargetUserTable,
+			Columns: []string{pendingauthsession.TargetUserColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+			},
+		}
+		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+	}
+	if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.M2O,
+			Inverse: true,
+			Table:   pendingauthsession.TargetUserTable,
+			Columns: []string{pendingauthsession.TargetUserColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+			},
+		}
+		for _, k := range nodes {
+			edge.Target.Nodes = append(edge.Target.Nodes, k)
+		}
+		_spec.Edges.Add = append(_spec.Edges.Add, edge)
+	}
+	// adoption_decision edge (O2O, owning side): same clear-then-add pattern.
+	if _u.mutation.AdoptionDecisionCleared() {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.O2O,
+			Inverse: false,
+			Table:   pendingauthsession.AdoptionDecisionTable,
+			Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+			},
+		}
+		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+	}
+	if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.O2O,
+			Inverse: false,
+			Table:   pendingauthsession.AdoptionDecisionTable,
+			Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+			},
+		}
+		for _, k := range nodes {
+			edge.Target.Nodes = append(edge.Target.Nodes, k)
+		}
+		_spec.Edges.Add = append(_spec.Edges.Add, edge)
+	}
+	if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+		if _, ok := err.(*sqlgraph.NotFoundError); ok {
+			err = &NotFoundError{pendingauthsession.Label}
+		} else if sqlgraph.IsConstraintError(err) {
+			err = &ConstraintError{msg: err.Error(), wrap: err}
+		}
+		return 0, err
+	}
+	// Mark the mutation done so hooks cannot re-run it.
+	_u.mutation.done = true
+	return _node, nil
+}
+
+// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity.
+type PendingAuthSessionUpdateOne struct {
+	config
+	// fields optionally restricts the columns returned by Save (set via Select).
+	fields []string
+	// hooks run around sqlSave via withHooks.
+	hooks []Hook
+	// mutation accumulates all staged field and edge changes.
+	mutation *PendingAuthSessionMutation
+}
+
+// NOTE(review): ent-generated builder accessors (single-entity variant) —
+// regenerate with `go generate ./ent` rather than hand-editing; setters only
+// stage values on the mutation, deferring SQL work to Save/sqlSave.
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetUpdatedAt(v)
+	return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetSessionToken(v)
+	return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetSessionToken(*v)
+	}
+	return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetIntent(v)
+	return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetIntent(*v)
+	}
+	return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetProviderType(v)
+	return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetProviderType(*v)
+	}
+	return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetProviderKey(v)
+	return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetProviderKey(*v)
+	}
+	return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetProviderSubject(v)
+	return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetProviderSubject(*v)
+	}
+	return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetTargetUserID(v)
+	return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetTargetUserID(*v)
+	}
+	return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearTargetUserID()
+	return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetRedirectTo(v)
+	return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetRedirectTo(*v)
+	}
+	return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetResolvedEmail(v)
+	return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetResolvedEmail(*v)
+	}
+	return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetRegistrationPasswordHash(v)
+	return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetRegistrationPasswordHash(*v)
+	}
+	return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetUpstreamIdentityClaims(v)
+	return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetLocalFlowState(v)
+	return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetBrowserSessionKey(v)
+	return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetBrowserSessionKey(*v)
+	}
+	return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetCompletionCodeHash(v)
+	return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetCompletionCodeHash(*v)
+	}
+	return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetCompletionCodeExpiresAt(v)
+	return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetCompletionCodeExpiresAt(*v)
+	}
+	return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearCompletionCodeExpiresAt()
+	return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetEmailVerifiedAt(v)
+	return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetEmailVerifiedAt(*v)
+	}
+	return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearEmailVerifiedAt()
+	return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetPasswordVerifiedAt(v)
+	return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetPasswordVerifiedAt(*v)
+	}
+	return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearPasswordVerifiedAt()
+	return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetTotpVerifiedAt(v)
+	return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetTotpVerifiedAt(*v)
+	}
+	return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearTotpVerifiedAt()
+	return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetExpiresAt(v)
+	return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetExpiresAt(*v)
+	}
+	return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetConsumedAt(v)
+	return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+	if v != nil {
+		_u.SetConsumedAt(*v)
+	}
+	return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearConsumedAt()
+	return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne {
+	return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne {
+	_u.mutation.SetAdoptionDecisionID(id)
+	return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne {
+	if id != nil {
+		_u = _u.SetAdoptionDecisionID(*id)
+	}
+	return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne {
+	return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation {
+	return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearTargetUser()
+	return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne {
+	_u.mutation.ClearAdoptionDecision()
+	return _u
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne {
+	_u.mutation.Where(ps...)
+	return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne {
+	_u.fields = append([]string{field}, fields...)
+	return _u
+}
+
+// Save executes the query and returns the updated PendingAuthSession entity.
+func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) {
+	// Populate defaults (e.g. updated_at) before running hooks and the SQL update.
+	_u.defaults()
+	return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession {
+	node, err := _u.Save(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error {
+	_, err := _u.Save(ctx)
+	return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) {
+	if err := _u.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdateOne) defaults() {
+	// updated_at is auto-touched on every update unless the caller set it explicitly.
+	if _, ok := _u.mutation.UpdatedAt(); !ok {
+		v := pendingauthsession.UpdateDefaultUpdatedAt()
+		_u.mutation.SetUpdatedAt(v)
+	}
+}
+
+// check runs all checks and user-defined validators on the builder.
+// Only fields actually present on the mutation are validated, so partial
+// updates skip validators for untouched columns.
+func (_u *PendingAuthSessionUpdateOne) check() error {
+	if v, ok := _u.mutation.SessionToken(); ok {
+		if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+			return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.Intent(); ok {
+		if err := pendingauthsession.IntentValidator(v); err != nil {
+			return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.ProviderType(); ok {
+		if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+			return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.ProviderKey(); ok {
+		if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+			return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+		}
+	}
+	if v, ok := _u.mutation.ProviderSubject(); ok {
+		if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+			return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+		}
+	}
+	return nil
+}
+
+// sqlSave validates the mutation, builds an UPDATE spec targeting the single
+// entity identified by the mutation's ID, executes it, and returns the updated
+// entity. Requires an ID on the mutation; honors Select(...) column restriction.
+func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) {
+	if err := _u.check(); err != nil {
+		return _node, err
+	}
+	_spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+	// UpdateOne must know which row to update; the ID comes from the mutation.
+	id, ok := _u.mutation.ID()
+	if !ok {
+		return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)}
+	}
+	_spec.Node.ID.Value = id
+	// Restrict the returned columns if the caller used Select(...); the ID
+	// column is always included so the entity can be scanned back.
+	if fields := _u.fields; len(fields) > 0 {
+		_spec.Node.Columns = make([]string, 0, len(fields))
+		_spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+		for _, f := range fields {
+			if !pendingauthsession.ValidColumn(f) {
+				return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+			}
+			if f != pendingauthsession.FieldID {
+				_spec.Node.Columns = append(_spec.Node.Columns, f)
+			}
+		}
+	}
+	if ps := _u.mutation.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	if value, ok := _u.mutation.UpdatedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+	}
+	if value, ok := _u.mutation.SessionToken(); ok {
+		_spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.Intent(); ok {
+		_spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ProviderType(); ok {
+		_spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ProviderKey(); ok {
+		_spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ProviderSubject(); ok {
+		_spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.RedirectTo(); ok {
+		_spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.ResolvedEmail(); ok {
+		_spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+		_spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+		_spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+	}
+	if value, ok := _u.mutation.LocalFlowState(); ok {
+		_spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+	}
+	if value, ok := _u.mutation.BrowserSessionKey(); ok {
+		_spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.CompletionCodeHash(); ok {
+		_spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+	}
+	if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+		_spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+	}
+	if _u.mutation.CompletionCodeExpiresAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+	}
+	if _u.mutation.EmailVerifiedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+	}
+	if _u.mutation.PasswordVerifiedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+	}
+	if _u.mutation.TotpVerifiedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+	}
+	if value, ok := _u.mutation.ExpiresAt(); ok {
+		_spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+	}
+	if value, ok := _u.mutation.ConsumedAt(); ok {
+		_spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+	}
+	if _u.mutation.ConsumedAtCleared() {
+		_spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+	}
+	// target_user edge (M2O, inverse): clear first, then add replacement IDs.
+	if _u.mutation.TargetUserCleared() {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.M2O,
+			Inverse: true,
+			Table:   pendingauthsession.TargetUserTable,
+			Columns: []string{pendingauthsession.TargetUserColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+			},
+		}
+		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+	}
+	if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.M2O,
+			Inverse: true,
+			Table:   pendingauthsession.TargetUserTable,
+			Columns: []string{pendingauthsession.TargetUserColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+			},
+		}
+		for _, k := range nodes {
+			edge.Target.Nodes = append(edge.Target.Nodes, k)
+		}
+		_spec.Edges.Add = append(_spec.Edges.Add, edge)
+	}
+	// adoption_decision edge (O2O, owning side): same clear-then-add pattern.
+	if _u.mutation.AdoptionDecisionCleared() {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.O2O,
+			Inverse: false,
+			Table:   pendingauthsession.AdoptionDecisionTable,
+			Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+			},
+		}
+		_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+	}
+	if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+		edge := &sqlgraph.EdgeSpec{
+			Rel:     sqlgraph.O2O,
+			Inverse: false,
+			Table:   pendingauthsession.AdoptionDecisionTable,
+			Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+			Bidi:    false,
+			Target: &sqlgraph.EdgeTarget{
+				IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+			},
+		}
+		for _, k := range nodes {
+			edge.Target.Nodes = append(edge.Target.Nodes, k)
+		}
+		_spec.Edges.Add = append(_spec.Edges.Add, edge)
+	}
+	// Scan the updated row back into a fresh entity via the spec's assign hooks.
+	_node = &PendingAuthSession{config: _u.config}
+	_spec.Assign = _node.assignValues
+	_spec.ScanValues = _node.scanValues
+	if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+		if _, ok := err.(*sqlgraph.NotFoundError); ok {
+			err = &NotFoundError{pendingauthsession.Label}
+		} else if sqlgraph.IsConstraintError(err) {
+			err = &ConstraintError{msg: err.Error(), wrap: err}
+		}
+		return nil, err
+	}
+	// Mark the mutation done so hooks cannot re-run it.
+	_u.mutation.done = true
+	return _node, nil
+}
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index ef551940..0aa90b90 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -21,6 +21,12 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
+// AuthIdentity is the predicate function for authidentity builders.
+type AuthIdentity func(*sql.Selector)
+
+// AuthIdentityChannel is the predicate function for authidentitychannel builders.
+type AuthIdentityChannel func(*sql.Selector)
+
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
@@ -30,6 +36,9 @@ type Group func(*sql.Selector)
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
type IdempotencyRecord func(*sql.Selector)
+// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders.
+type IdentityAdoptionDecision func(*sql.Selector)
+
// PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector)
@@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector)
// PaymentProviderInstance is the predicate function for paymentproviderinstance builders.
type PaymentProviderInstance func(*sql.Selector)
+// PendingAuthSession is the predicate function for pendingauthsession builders.
+type PendingAuthSession func(*sql.Selector)
+
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index fbdd08c7..268e9ddb 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -10,12 +10,16 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -309,6 +313,120 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
+ authidentityMixin := schema.AuthIdentity{}.Mixin()
+ authidentityMixinFields0 := authidentityMixin[0].Fields()
+ _ = authidentityMixinFields0
+ authidentityFields := schema.AuthIdentity{}.Fields()
+ _ = authidentityFields
+ // authidentityDescCreatedAt is the schema descriptor for created_at field.
+ authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor()
+ // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time)
+ // authidentityDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor()
+ // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time)
+ // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentityDescProviderType is the schema descriptor for provider_type field.
+ authidentityDescProviderType := authidentityFields[1].Descriptor()
+ // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentity.ProviderTypeValidator = func() func(string) error {
+ validators := authidentityDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentityDescProviderKey is the schema descriptor for provider_key field.
+ authidentityDescProviderKey := authidentityFields[2].Descriptor()
+ // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error)
+ // authidentityDescProviderSubject is the schema descriptor for provider_subject field.
+ authidentityDescProviderSubject := authidentityFields[3].Descriptor()
+ // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error)
+ // authidentityDescMetadata is the schema descriptor for metadata field.
+ authidentityDescMetadata := authidentityFields[6].Descriptor()
+ // authidentity.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{})
+ authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin()
+ authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields()
+ _ = authidentitychannelMixinFields0
+ authidentitychannelFields := schema.AuthIdentityChannel{}.Fields()
+ _ = authidentitychannelFields
+ // authidentitychannelDescCreatedAt is the schema descriptor for created_at field.
+ authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor()
+ // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time)
+ // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor()
+ // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time)
+ // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentitychannelDescProviderType is the schema descriptor for provider_type field.
+ authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor()
+ // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentitychannel.ProviderTypeValidator = func() func(string) error {
+ validators := authidentitychannelDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescProviderKey is the schema descriptor for provider_key field.
+ authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor()
+ // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error)
+ // authidentitychannelDescChannel is the schema descriptor for channel field.
+ authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor()
+ // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ authidentitychannel.ChannelValidator = func() func(string) error {
+ validators := authidentitychannelDescChannel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(channel string) error {
+ for _, fn := range fns {
+ if err := fn(channel); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field.
+ authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor()
+ // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error)
+ // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field.
+ authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor()
+ // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error)
+ // authidentitychannelDescMetadata is the schema descriptor for metadata field.
+ authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor()
+ // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{})
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
@@ -512,6 +630,33 @@ func init() {
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
+ identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin()
+ identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields()
+ _ = identityadoptiondecisionMixinFields0
+ identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields()
+ _ = identityadoptiondecisionFields
+ // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field.
+ identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor()
+ // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field.
+ identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time)
+ // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field.
+ identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor()
+ // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time)
+ // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field.
+ identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor()
+ // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field.
+ identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool)
+ // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field.
+ identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor()
+ // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field.
+ identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool)
+ // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field.
+ identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor()
+ // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field.
+ identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time)
paymentauditlogFields := schema.PaymentAuditLog{}.Fields()
_ = paymentauditlogFields
// paymentauditlogDescOrderID is the schema descriptor for order_id field.
@@ -682,6 +827,113 @@ func init() {
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time)
+ pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin()
+ pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields()
+ _ = pendingauthsessionMixinFields0
+ pendingauthsessionFields := schema.PendingAuthSession{}.Fields()
+ _ = pendingauthsessionFields
+ // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field.
+ pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor()
+ // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field.
+ pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time)
+ // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field.
+ pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor()
+ // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time)
+ // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // pendingauthsessionDescSessionToken is the schema descriptor for session_token field.
+ pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor()
+ // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ pendingauthsession.SessionTokenValidator = func() func(string) error {
+ validators := pendingauthsessionDescSessionToken.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(session_token string) error {
+ for _, fn := range fns {
+ if err := fn(session_token); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescIntent is the schema descriptor for intent field.
+ pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor()
+ // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ pendingauthsession.IntentValidator = func() func(string) error {
+ validators := pendingauthsessionDescIntent.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(intent string) error {
+ for _, fn := range fns {
+ if err := fn(intent); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderType is the schema descriptor for provider_type field.
+ pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor()
+ // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ pendingauthsession.ProviderTypeValidator = func() func(string) error {
+ validators := pendingauthsessionDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field.
+ pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor()
+ // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error)
+ // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field.
+ pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor()
+ // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error)
+ // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field.
+ pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor()
+ // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field.
+ pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string)
+ // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field.
+ pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor()
+ // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field.
+ pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string)
+ // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field.
+ pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor()
+ // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field.
+ pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string)
+ // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field.
+ pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor()
+ // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field.
+ pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{})
+ // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field.
+ pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor()
+ // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field.
+ pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{})
+ // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field.
+ pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor()
+ // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field.
+ pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string)
+ // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field.
+ pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor()
+ // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field.
+ pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -1297,20 +1549,26 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescSignupSource is the schema descriptor for signup_source field.
+ userDescSignupSource := userFields[11].Descriptor()
+ // user.DefaultSignupSource holds the default value on creation for the signup_source field.
+ user.DefaultSignupSource = userDescSignupSource.Default.(string)
+ // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
- userDescBalanceNotifyEnabled := userFields[11].Descriptor()
+ userDescBalanceNotifyEnabled := userFields[14].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
- userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
+ userDescBalanceNotifyThresholdType := userFields[15].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
- userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
+ userDescBalanceNotifyExtraEmails := userFields[17].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
- userDescTotalRecharged := userFields[15].Descriptor()
+ userDescTotalRecharged := userFields[18].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go
new file mode 100644
index 00000000..e4b9ac90
--- /dev/null
+++ b/backend/ent/schema/auth_identity.go
@@ -0,0 +1,93 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var authProviderTypes = map[string]struct{}{
+ "email": {},
+ "linuxdo": {},
+ "oidc": {},
+ "wechat": {},
+}
+
+func validateAuthProviderType(value string) error {
+ if _, ok := authProviderTypes[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid auth provider type %q", value)
+}
+
+// AuthIdentity stores the canonical login identity for an account.
+type AuthIdentity struct {
+ ent.Schema
+}
+
+func (AuthIdentity) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identities"},
+ }
+}
+
+func (AuthIdentity) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentity) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("user_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("issuer").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentity) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("user", User.Type).
+ Ref("auth_identities").
+ Field("user_id").
+ Required().
+ Unique(),
+ edge.To("channels", AuthIdentityChannel.Type),
+ edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
+ }
+}
+
+func (AuthIdentity) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "provider_subject").Unique(),
+ index.Fields("user_id"),
+ index.Fields("user_id", "provider_type"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go
new file mode 100644
index 00000000..69f2ad02
--- /dev/null
+++ b/backend/ent/schema/auth_identity_channel.go
@@ -0,0 +1,72 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity.
+type AuthIdentityChannel struct {
+ ent.Schema
+}
+
+func (AuthIdentityChannel) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identity_channels"},
+ }
+}
+
+func (AuthIdentityChannel) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentityChannel) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("identity_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel").
+ MaxLen(20).
+ NotEmpty(),
+ field.String("channel_app_id").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentityChannel) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("identity", AuthIdentity.Type).
+ Ref("channels").
+ Field("identity_id").
+ Required().
+ Unique(),
+ }
+}
+
+func (AuthIdentityChannel) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go
new file mode 100644
index 00000000..de55dd69
--- /dev/null
+++ b/backend/ent/schema/auth_identity_schema_test.go
@@ -0,0 +1,124 @@
+package schema
+
+import (
+ "testing"
+
+ "entgo.io/ent/entc/load"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationSchemas(t *testing.T) {
+ spec, err := (&load.Config{Path: "."}).Load()
+ require.NoError(t, err)
+
+ schemas := map[string]*load.Schema{}
+ for _, schema := range spec.Schemas {
+ schemas[schema.Name] = schema
+ }
+
+ authIdentity := requireSchema(t, schemas, "AuthIdentity")
+ requireSchemaFields(t, authIdentity,
+ "user_id",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "verified_at",
+ "issuer",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject")
+
+ authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel")
+ requireSchemaFields(t, authIdentityChannel,
+ "identity_id",
+ "provider_type",
+ "provider_key",
+ "channel",
+ "channel_app_id",
+ "channel_subject",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject")
+
+ pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession")
+ requireSchemaFields(t, pendingAuthSession,
+ "intent",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "target_user_id",
+ "redirect_to",
+ "resolved_email",
+ "registration_password_hash",
+ "upstream_identity_claims",
+ "local_flow_state",
+ "browser_session_key",
+ "completion_code_hash",
+ "completion_code_expires_at",
+ "email_verified_at",
+ "password_verified_at",
+ "totp_verified_at",
+ "expires_at",
+ "consumed_at",
+ )
+
+ adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision")
+ requireSchemaFields(t, adoptionDecision,
+ "pending_auth_session_id",
+ "identity_id",
+ "adopt_display_name",
+ "adopt_avatar",
+ "decided_at",
+ )
+ requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id")
+
+ userSchema := requireSchema(t, schemas, "User")
+ requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
+}
+
+func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
+ t.Helper()
+
+ schema, ok := schemas[name]
+ require.True(t, ok, "schema %s should exist", name)
+ return schema
+}
+
+func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
+ t.Helper()
+
+ fields := map[string]struct{}{}
+ for _, field := range schema.Fields {
+ fields[field.Name] = struct{}{}
+ }
+
+ for _, name := range names {
+ _, ok := fields[name]
+ require.True(t, ok, "schema %s should include field %s", schema.Name, name)
+ }
+}
+
+func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
+ t.Helper()
+
+ for _, index := range schema.Indexes {
+ if !index.Unique {
+ continue
+ }
+ if len(index.Fields) != len(fields) {
+ continue
+ }
+ match := true
+ for i := range fields {
+ if index.Fields[i] != fields[i] {
+ match = false
+ break
+ }
+ }
+ if match {
+ return
+ }
+ }
+
+ require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields)
+}
diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go
new file mode 100644
index 00000000..9fdd26fb
--- /dev/null
+++ b/backend/ent/schema/identity_adoption_decision.go
@@ -0,0 +1,70 @@
+package schema
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow.
+type IdentityAdoptionDecision struct {
+ ent.Schema
+}
+
+func (IdentityAdoptionDecision) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "identity_adoption_decisions"},
+ }
+}
+
+func (IdentityAdoptionDecision) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (IdentityAdoptionDecision) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("pending_auth_session_id"),
+ field.Int64("identity_id").
+ Optional().
+ Nillable(),
+ field.Bool("adopt_display_name").
+ Default(false),
+ field.Bool("adopt_avatar").
+ Default(false),
+ field.Time("decided_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (IdentityAdoptionDecision) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("pending_auth_session", PendingAuthSession.Type).
+ Ref("adoption_decision").
+ Field("pending_auth_session_id").
+ Required().
+ Unique(),
+ edge.From("identity", AuthIdentity.Type).
+ Ref("adoption_decisions").
+ Field("identity_id").
+ Unique(),
+ }
+}
+
+func (IdentityAdoptionDecision) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("pending_auth_session_id").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go
new file mode 100644
index 00000000..91341d49
--- /dev/null
+++ b/backend/ent/schema/pending_auth_session.go
@@ -0,0 +1,134 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var pendingAuthIntents = map[string]struct{}{
+ "login": {},
+ "bind_current_user": {},
+ "adopt_existing_user_by_email": {},
+}
+
+func validatePendingAuthIntent(value string) error {
+ if _, ok := pendingAuthIntents[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid pending auth intent %q", value)
+}
+
+// PendingAuthSession stores a short-lived post-auth decision session.
+type PendingAuthSession struct {
+ ent.Schema
+}
+
+func (PendingAuthSession) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "pending_auth_sessions"},
+ }
+}
+
+func (PendingAuthSession) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (PendingAuthSession) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("session_token").
+ MaxLen(255).
+ NotEmpty(),
+ field.String("intent").
+ MaxLen(40).
+ NotEmpty().
+ Validate(validatePendingAuthIntent),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Int64("target_user_id").
+ Optional().
+ Nillable(),
+ field.String("redirect_to").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("resolved_email").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("registration_password_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("upstream_identity_claims", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.JSON("local_flow_state", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.String("browser_session_key").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("completion_code_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("completion_code_expires_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("email_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("password_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("totp_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("expires_at").
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("consumed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PendingAuthSession) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("target_user", User.Type).
+ Ref("pending_auth_sessions").
+ Field("target_user_id").
+ Unique(),
+ edge.To("adoption_decision", IdentityAdoptionDecision.Type).
+ Unique(),
+ }
+}
+
+func (PendingAuthSession) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("session_token").Unique(),
+ index.Fields("target_user_id"),
+ index.Fields("expires_at"),
+ index.Fields("provider_type", "provider_key", "provider_subject"),
+ index.Fields("completion_code_hash"),
+ }
+}
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index ef52e985..bb58d9e3 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,6 +72,17 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+ field.String("signup_source").
+ MaxLen(20).
+ Default("email"),
+ field.Time("last_login_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("last_active_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知
field.Bool("balance_notify_enabled").
@@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
+ edge.To("auth_identities", AuthIdentity.Type),
+ edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index bb3139d5..bde3e35b 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -24,18 +24,26 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -202,12 +210,16 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
+ tx.AuthIdentity = NewAuthIdentityClient(tx.config)
+ tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
+ tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
+ tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 9fa91f74..66f33623 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // SignupSource holds the value of the "signup_source" field.
+ SignupSource string `json:"signup_source,omitempty"`
+ // LastLoginAt holds the value of the "last_login_at" field.
+ LastLoginAt *time.Time `json:"last_login_at,omitempty"`
+ // LastActiveAt holds the value of the "last_active_at" field.
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
@@ -83,11 +89,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
+ // AuthIdentities holds the value of the auth_identities edge.
+ AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
+ // PendingAuthSessions holds the value of the pending_auth_sessions edge.
+ PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [11]bool
+ loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"}
}
+// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
+ if e.loadedTypes[10] {
+ return e.AuthIdentities, nil
+ }
+ return nil, &NotLoadedError{edge: "auth_identities"}
+}
+
+// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
+ if e.loadedTypes[11] {
+ return e.PendingAuthSessions, nil
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_sessions"}
+}
+
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
- if e.loadedTypes[10] {
+ if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldSignupSource:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field signup_source", values[i])
+ } else if value.Valid {
+ _m.SignupSource = value.String
+ }
+ case user.FieldLastLoginAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
+ } else if value.Valid {
+ _m.LastLoginAt = new(time.Time)
+ *_m.LastLoginAt = value.Time
+ }
+ case user.FieldLastActiveAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
+ } else if value.Valid {
+ _m.LastActiveAt = new(time.Time)
+ *_m.LastActiveAt = value.Time
+ }
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
@@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m)
}
+// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
+func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
+ return NewUserClient(_m.config).QueryAuthIdentities(_m)
+}
+
+// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
+func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
+}
+
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -482,6 +540,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
+ builder.WriteString("signup_source=")
+ builder.WriteString(_m.SignupSource)
+ builder.WriteString(", ")
+ if v := _m.LastLoginAt; v != nil {
+ builder.WriteString("last_login_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.LastActiveAt; v != nil {
+ builder.WriteString("last_active_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index d88a3a38..567e3b14 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldSignupSource holds the string denoting the signup_source field in the database.
+ FieldSignupSource = "signup_source"
+ // FieldLastLoginAt holds the string denoting the last_login_at field in the database.
+ FieldLastLoginAt = "last_login_at"
+ // FieldLastActiveAt holds the string denoting the last_active_at field in the database.
+ FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
@@ -73,6 +79,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders"
+ // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
+ EdgeAuthIdentities = "auth_identities"
+ // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
+ EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -145,6 +155,20 @@ const (
PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id"
+ // AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
+ AuthIdentitiesTable = "auth_identities"
+ // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ AuthIdentitiesInverseTable = "auth_identities"
+ // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
+ AuthIdentitiesColumn = "user_id"
+ // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
+ PendingAuthSessionsTable = "pending_auth_sessions"
+ // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionsInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
+ PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -171,6 +195,9 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldSignupSource,
+ FieldLastLoginAt,
+ FieldLastActiveAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
@@ -232,6 +259,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultSignupSource holds the default value on creation for the "signup_source" field.
+ DefaultSignupSource string
+ // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
@@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// BySignupSource orders the results by the signup_source field.
+func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
+}
+
+// ByLastLoginAt orders the results by the last_login_at field.
+func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
+}
+
+// ByLastActiveAt orders the results by the last_active_at field.
+func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
+}
+
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
@@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
+// ByAuthIdentitiesCount orders the results by auth_identities count.
+func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
+ }
+}
+
+// ByAuthIdentities orders the results by auth_identities terms.
+func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
+func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
+ }
+}
+
+// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
+func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
)
}
+func newAuthIdentitiesStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+}
+func newPendingAuthSessionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 2788aa7a..cbcfcc26 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
+func SignupSource(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
+func LastLoginAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
+func LastActiveAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
+func SignupSourceEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
+func SignupSourceNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
+}
+
+// SignupSourceIn applies the In predicate on the "signup_source" field.
+func SignupSourceIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
+func SignupSourceNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceGT applies the GT predicate on the "signup_source" field.
+func SignupSourceGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSignupSource, v))
+}
+
+// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
+func SignupSourceGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSignupSource, v))
+}
+
+// SignupSourceLT applies the LT predicate on the "signup_source" field.
+func SignupSourceLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSignupSource, v))
+}
+
+// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
+func SignupSourceLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSignupSource, v))
+}
+
+// SignupSourceContains applies the Contains predicate on the "signup_source" field.
+func SignupSourceContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldSignupSource, v))
+}
+
+// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
+func SignupSourceHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
+}
+
+// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
+func SignupSourceHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
+}
+
+// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
+func SignupSourceEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
+}
+
+// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
+func SignupSourceContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
+}
+
+// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
+func LastLoginAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
+func LastLoginAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIn applies the In predicate on the "last_login_at" field.
+func LastLoginAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
+func LastLoginAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
+func LastLoginAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
+func LastLoginAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
+func LastLoginAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
+func LastLoginAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
+func LastLoginAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
+}
+
+// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
+func LastLoginAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
+}
+
+// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
+func LastActiveAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
+func LastActiveAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIn applies the In predicate on the "last_active_at" field.
+func LastActiveAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
+func LastActiveAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
+func LastActiveAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
+func LastActiveAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
+func LastActiveAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
+func LastActiveAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
+func LastActiveAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
+}
+
+// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
+func LastActiveAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
+}
+
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
})
}
+// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
+func HasAuthIdentities() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
+func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newAuthIdentitiesStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
+func HasPendingAuthSessions() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
+func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newPendingAuthSessionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index fbc64f9c..db95e813 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetSignupSource sets the "signup_source" field.
+func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
+ _c.mutation.SetSignupSource(v)
+ return _c
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
+ if v != nil {
+ _c.SetSignupSource(*v)
+ }
+ return _c
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastLoginAt(v)
+ return _c
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastLoginAt(*v)
+ }
+ return _c
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastActiveAt(v)
+ return _c
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastActiveAt(*v)
+ }
+ return _c
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddAuthIdentityIDs(ids...)
+ return _c
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddPendingAuthSessionIDs(ids...)
+ return _c
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ v := user.DefaultSignupSource
+ _c.mutation.SetSignupSource(v)
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -589,6 +667,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
+ }
+ if v, ok := _c.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
@@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ _node.SignupSource = value
+ }
+ if value, ok := _c.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ _node.LastLoginAt = &value
+ }
+ if value, ok := _c.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ _node.LastActiveAt = &value
+ }
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
@@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
@@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
+ u.Set(user.FieldSignupSource, v)
+ return u
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
+ u.SetExcluded(user.FieldSignupSource)
+ return u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastLoginAt, v)
+ return u
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastLoginAt)
+ return u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
+ u.SetNull(user.FieldLastLoginAt)
+ return u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastActiveAt, v)
+ return u
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastActiveAt)
+ return u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
+ u.SetNull(user.FieldLastActiveAt)
+ return u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
@@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go
index 113d87ac..f1ee5cfe 100644
--- a/backend/ent/user_query.go
+++ b/backend/ent/user_query.go
@@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery
+ withAuthIdentities *AuthIdentityQuery
+ withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities chains the current query on the "auth_identities" edge.
+func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
+func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(),
+ withAuthIdentities: _q.withAuthIdentities.Clone(),
+ withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q
}
+// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
+// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAuthIdentities = query
+ return _q
+}
+
+// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSessions = query
+ return _q
+}
+
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
- loadedTypes = [11]bool{
+ loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil,
+ _q.withAuthIdentities != nil,
+ _q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
+ if query := _q.withAuthIdentities; query != nil {
+ if err := _q.loadAuthIdentities(ctx, query, nodes,
+ func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
+ func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withPendingAuthSessions; query != nil {
+ if err := _q.loadPendingAuthSessions(ctx, query, nodes,
+ func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
+ func(n *User, e *PendingAuthSession) {
+ n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
}
return nil
}
+func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentity.FieldUserID)
+ }
+ query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
+ }
+ query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TargetUserID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 6b355247..677eeb6b 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index dd9a4e58..6136e9ea 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -1608,6 +1608,9 @@ func (c *Config) Validate() error {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.LinuxDo.Enabled {
+ if !c.LinuxDo.UsePKCE {
+ return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true")
+ }
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
@@ -1629,9 +1632,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.LinuxDo.UsePKCE {
- return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
@@ -1663,6 +1663,12 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.OIDC.Enabled {
+ if !c.OIDC.UsePKCE {
+ return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true")
+ }
+ if !c.OIDC.ValidateIDToken {
+ return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true")
+ }
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}
@@ -1685,9 +1691,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.OIDC.UsePKCE {
- return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index bec0f126..fe5c7928 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
@@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
+ payload := dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
@@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
- })
+ }
+ response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// UpdateSettingsRequest 更新设置请求
@@ -276,9 +282,30 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
+ AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
+ AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
+ AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
+ AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
+ AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
+ AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
+ AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
+ AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
+ AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
+ AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
+ AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
+ AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
+ AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
+ AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
+ AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
+ AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
+ AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
+ AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
+ AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
+ ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// 验证参数
if req.DefaultConcurrency < 1 {
@@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
+ req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
+ req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
+ req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
+ req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
+ if !req.OIDCConnectUsePKCE {
+ response.BadRequest(c, "OIDC PKCE must be enabled")
+ return
+ }
+ if !req.OIDCConnectValidateIDToken {
+ response.BadRequest(c, "OIDC ID Token validation must be enabled")
+ return
+ }
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
- if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
- response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
- return
- }
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
- if req.OIDCConnectValidateIDToken {
- if req.OIDCConnectAllowedSigningAlgs == "" {
- response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
- return
- }
+ if req.OIDCConnectAllowedSigningAlgs == "" {
+ response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
+ return
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
@@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults := &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
+ },
+ LinuxDo: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
+ },
+ OIDC: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
+ },
+ WeChat: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
+ },
+ ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
+ }
+ if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
@@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
@@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
+ payload := dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
@@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
- })
+ }
+ response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
@@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
+func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
+ if input == nil {
+ return nil
+ }
+ normalized := normalizeDefaultSubscriptions(*input)
+ return &normalized
+}
+
+func float64ValueOrDefault(value *float64, fallback float64) float64 {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func intValueOrDefault(value *int, fallback int) int {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func boolValueOrDefault(value *bool, fallback bool) bool {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
+ if input == nil {
+ return fallback
+ }
+ result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
+ for _, item := range *input {
+ result = append(result, service.DefaultSubscriptionSetting{
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ })
+ }
+ return result
+}
+
+func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
+ data := make(map[string]any)
+ raw, err := json.Marshal(settings)
+ if err == nil {
+ _ = json.Unmarshal(raw, &data)
+ }
+ if authSourceDefaults == nil {
+ authSourceDefaults = &service.AuthSourceDefaultSettings{}
+ }
+
+ data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
+ data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
+ data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
+ data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
+ data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
+ data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
+ data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
+ data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
+ data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
+ data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
+ data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
+ data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
+ data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
+ data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
+ data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
+ data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
+ data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
+ data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
+ data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
+ data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
+ data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
+
+ return data
+}
+
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
new file mode 100644
index 00000000..b26fa447
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -0,0 +1,149 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerRepoStub struct {
+ values map[string]string
+ lastUpdates map[string]string
+}
+
+func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.lastUpdates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.lastUpdates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
+
+ handler.GetSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 9.5, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+
+ subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
+ require.True(t, ok)
+ require.Len(t, subscriptions, 1)
+}
+
+func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
+ require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 12.75, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+}
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 2f182642..b0edcf5a 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
@@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
ProviderKey: "linuxdo",
ProviderSubject: subject,
},
+ TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
type completeLinuxDoOAuthRequest struct {
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index 90bc10d1..661c0da0 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -1,10 +1,21 @@
package handler
import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
"strings"
"testing"
+ "time"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
+
+func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-1").
+ SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Display",
+ "suggested_avatar_url": "https://cdn.example/linuxdo.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "LinuxDo Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index a758c0b9..da8ac858 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -1,10 +1,17 @@
package handler
import (
+ "context"
+ "errors"
+ "io"
"net/http"
"net/url"
"strings"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -26,6 +33,7 @@ const (
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
+ TargetUserID *int64
ResolvedEmail string
RedirectTo string
BrowserSessionKey string
@@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct {
CompletionResponse map[string]any
}
+type oauthAdoptionDecisionRequest struct {
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
Intent: strings.TrimSpace(payload.Intent),
Identity: payload.Identity,
+ TargetUserID: payload.TargetUserID,
ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
RedirectTo: strings.TrimSpace(payload.RedirectTo),
BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
@@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
+func (r oauthAdoptionDecisionRequest) hasDecision() bool {
+ return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
+}
+
+func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput {
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if r.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *r.AdoptDisplayName
+ }
+ if r.AdoptAvatar != nil {
+ input.AdoptAvatar = *r.AdoptAvatar
+ }
+ return input
+}
+
+func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
+ var req oauthAdoptionDecisionRequest
+ if c == nil || c.Request == nil || c.Request.Body == nil {
+ return req, nil
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ if errors.Is(err, io.EOF) {
+ return req, nil
+ }
+ return req, err
+ }
+ return req, nil
+}
+
+func persistPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ svc *service.AuthPendingIdentityService,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) error {
+ if !req.hasDecision() {
+ return nil
+ }
+ if svc == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return nil
+}
+
+func cloneOAuthMetadata(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func normalizeAdoptedOAuthDisplayName(value string) string {
+ value = strings.TrimSpace(value)
+ if len([]rune(value)) > 100 {
+ value = string([]rune(value)[:100])
+ }
+ return value
+}
+
+func (h *AuthHandler) entClient() *dbent.Client {
+ if h == nil || h.authService == nil {
+ return nil
+ }
+ return h.authService.EntClient()
+}
+
+func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ existing, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
+ Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
+ }
+ if existing != nil && !req.hasDecision() {
+ return existing, nil
+ }
+ if existing == nil && !req.hasDecision() {
+ return nil, nil
+ }
+
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if existing != nil {
+ input.AdoptDisplayName = existing.AdoptDisplayName
+ input.AdoptAvatar = existing.AdoptAvatar
+ input.IdentityID = existing.IdentityID
+ }
+ if req.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *req.AdoptDisplayName
+ }
+ if req.AdoptAvatar != nil {
+ input.AdoptAvatar = *req.AdoptAvatar
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
+ if session == nil {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return *session.TargetUserID, nil
+ }
+ email := strings.TrimSpace(session.ResolvedEmail)
+ if email == "" {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
+ }
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(email)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
+ }
+ return 0, err
+ }
+ return userEntity.ID, nil
+}
+
+func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
+ if session == nil {
+ return nil
+ }
+ switch strings.TrimSpace(session.ProviderType) {
+ case "oidc":
+ issuer := strings.TrimSpace(session.ProviderKey)
+ if issuer == "" {
+ issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ }
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ default:
+ issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ }
+}
+
+func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ client := tx.Client()
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if identity != nil {
+ if identity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return identity, nil
+ }
+
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(strings.TrimSpace(session.ProviderType)).
+ SetProviderKey(strings.TrimSpace(session.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
+ SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ return create.Save(ctx)
+}
+
+func applyPendingOAuthAdoption(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+) error {
+ if client == nil || session == nil || decision == nil {
+ return nil
+ }
+ if !decision.AdoptDisplayName && !decision.AdoptAvatar {
+ return nil
+ }
+
+ targetUserID := int64(0)
+ if overrideUserID != nil && *overrideUserID > 0 {
+ targetUserID = *overrideUserID
+ } else {
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ if err != nil {
+ return err
+ }
+ targetUserID = resolvedUserID
+ }
+
+ adoptedDisplayName := ""
+ if decision.AdoptDisplayName {
+ adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
+ }
+ adoptedAvatarURL := ""
+ if decision.AdoptAvatar {
+ adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if err := tx.Client().User.UpdateOneID(targetUserID).
+ SetUsername(adoptedDisplayName).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
+ if err != nil {
+ return err
+ }
+
+ metadata := cloneOAuthMetadata(identity.Metadata)
+ for key, value := range session.UpstreamIdentityClaims {
+ metadata[key] = value
+ }
+ if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ metadata["display_name"] = adoptedDisplayName
+ }
+ if decision.AdoptAvatar && adoptedAvatarURL != "" {
+ metadata["avatar_url"] = adoptedAvatarURL
+ }
+
+ updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ if _, err := updateIdentity.Save(ctx); err != nil {
+ return err
+ }
+
+ if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
+ if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
+ SetIdentityID(identity.ID).
+ Save(ctx); err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
+ adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
+ if err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
@@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
if pendingSessionWantsInvitation(payload) {
+ if adoptionDecision.hasDecision() {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ _ = decision
+ }
+ response.Success(c, payload)
+ return
+ }
+ if !adoptionDecision.hasDecision() {
response.Success(c, payload)
return
}
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearCookies()
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 5517bae2..829fc217 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -1,9 +1,30 @@
package handler
import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
@@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
require.Equal(t, true, payload["adoption_required"])
}
+
+func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("linuxdo-123@linuxdo-connect.invalid").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "https://cdn.example/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previewRecorder := httptest.NewRecorder()
+ previewCtx, _ := gin.CreateTestContext(previewRecorder)
+ previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ previewCtx.Request = previewReq
+
+ handler.ExchangePendingOAuthCompletion(previewCtx)
+
+ require.Equal(t, http.StatusOK, previewRecorder.Code)
+ previewData := decodeJSONResponseData(t, previewRecorder)
+ require.Equal(t, "Alice Example", previewData["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
+ require.Equal(t, true, previewData["adoption_required"])
+
+ storedUser, err := client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "legacy-name", storedUser.Username)
+
+ previewSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, previewSession.ConsumedAt)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ finalizeRecorder := httptest.NewRecorder()
+ finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+ finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ finalizeReq.Header.Set("Content-Type", "application/json")
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ finalizeCtx.Request = finalizeReq
+
+ handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+ require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", storedUser.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ },
+ }, cfg)
+ authSvc := service.NewAuthService(
+ client,
+ &oauthPendingFlowUserRepo{client: client},
+ nil,
+ &oauthPendingFlowRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ }, client
+}
+
+func boolSettingValue(v bool) string {
+ if v {
+ return "true"
+ }
+ return "false"
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+type oauthPendingFlowSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type oauthPendingFlowRefreshTokenCacheStub struct{}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var envelope struct {
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
+ return envelope.Data
+}
+
+func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ return payload
+}
+
+type oauthPendingFlowUserRepo struct {
+ client *dbent.Client
+}
+
+func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.Create().
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.ID = entity.ID
+ user.CreatedAt = entity.CreatedAt
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
+ entity, err := r.client.User.Get(ctx, id)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
+ entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.UpdateOneID(user.ID).
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
+ return r.client.User.DeleteOneID(id).Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ return nil, service.ErrUserNotFound
+}
+
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
+ panic("unexpected UpdateBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
+ panic("unexpected DeductBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
+ panic("unexpected UpdateConcurrency call")
+}
+
+func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
+ count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
+ return count > 0, err
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ panic("unexpected RemoveGroupFromAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected RemoveGroupFromUserAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
+ panic("unexpected UpdateTotpSecret call")
+}
+
+func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error {
+ panic("unexpected EnableTotp call")
+}
+
+func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error {
+ panic("unexpected DisableTotp call")
+}
+
+func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
+ if entity == nil {
+ return nil
+ }
+ return &service.User{
+ ID: entity.ID,
+ Email: entity.Email,
+ Username: entity.Username,
+ Notes: entity.Notes,
+ PasswordHash: entity.PasswordHash,
+ Role: entity.Role,
+ Balance: entity.Balance,
+ Concurrency: entity.Concurrency,
+ Status: entity.Status,
+ SignupSource: entity.SignupSource,
+ LastLoginAt: entity.LastLoginAt,
+ LastActiveAt: entity.LastActiveAt,
+ CreatedAt: entity.CreatedAt,
+ UpdatedAt: entity.UpdatedAt,
+ }
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index e3694c8f..ceda633c 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
)
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
@@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderKey: issuer,
ProviderSubject: subject,
},
+ TargetUserID: &user.ID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
@@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
type completeOIDCOAuthRequest struct {
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index c389db51..9107e13a 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -1,6 +1,7 @@
package handler
import (
+ "bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -12,7 +13,13 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
@@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
E: e,
}
}
+
+// TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision verifies
+// that completing a pending OIDC registration merges the adoption decision
+// stored server-side (adopt_avatar) with the one supplied in the request body
+// (adopt_display_name), applies both to the new user/identity, and consumes
+// the pending session.
+func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	// Seed a pending auth session as the OIDC callback would have created it,
+	// including the suggested profile fields the adoption step reads.
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("oidc-complete-session").
+		SetIntent("login").
+		SetProviderType("oidc").
+		SetProviderKey("https://issuer.example.com").
+		SetProviderSubject("oidc-subject-1").
+		SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
+		SetBrowserSessionKey("oidc-browser").
+		SetUpstreamIdentityClaims(map[string]any{
+			"username":               "oidc_user",
+			"issuer":                 "https://issuer.example.com",
+			"suggested_display_name": "OIDC Display",
+			"suggested_avatar_url":   "https://cdn.example/oidc.png",
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Pre-store an adoption decision with only AdoptAvatar set; the request
+	// below adds adopt_display_name, exercising the upsert/merge path.
+	_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+		PendingAuthSessionID: session.ID,
+		AdoptAvatar:          true,
+	})
+	require.NoError(t, err)
+
+	body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+	req.Header.Set("Content-Type", "application/json")
+	// Both the pending-session and browser-session cookies must match the
+	// seeded session for the handler to accept the completion.
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
+	c.Request = req
+
+	handler.CompleteOIDCOAuthRegistration(c)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+	responseData := decodeJSONBody(t, recorder)
+	require.NotEmpty(t, responseData["access_token"])
+
+	// Display-name adoption was applied to the newly registered user.
+	userEntity, err := client.User.Query().
+		Where(dbuser.EmailEQ(session.ResolvedEmail)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, "OIDC Display", userEntity.Username)
+
+	// The auth identity was linked to the user and carries the upstream
+	// profile metadata.
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("oidc"),
+			authidentity.ProviderKeyEQ("https://issuer.example.com"),
+			authidentity.ProviderSubjectEQ("oidc-subject-1"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, userEntity.ID, identity.UserID)
+	require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
+	require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
+
+	// The stored decision now reflects both flags and points at the identity.
+	decision, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, decision.IdentityID)
+	require.Equal(t, identity.ID, *decision.IdentityID)
+	require.True(t, decision.AdoptDisplayName)
+	require.True(t, decision.AdoptAvatar)
+
+	// Completion must consume the pending session so it cannot be replayed.
+	consumed, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, consumed.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
new file mode 100644
index 00000000..867a77a1
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -0,0 +1,618 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+	// Scope and lifetime of the short-lived cookies that carry flow state
+	// between /start and /callback. Path-scoped so they never leave the
+	// wechat oauth endpoints.
+	wechatOAuthCookiePath         = "/api/v1/auth/oauth/wechat"
+	wechatOAuthCookieMaxAgeSec    = 10 * 60
+	wechatOAuthStateCookieName    = "wechat_oauth_state"
+	wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
+	wechatOAuthIntentCookieName   = "wechat_oauth_intent"
+	wechatOAuthModeCookieName     = "wechat_oauth_mode"
+	// Fallback destinations when no explicit redirect/frontend callback is configured.
+	wechatOAuthDefaultRedirectTo = "/dashboard"
+	wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
+	// Provider key used for all wechat identities in the pending-auth store.
+	wechatOAuthProviderKey = "wechat-main"
+
+	// Supported flow intents (see normalizeWeChatOAuthIntent).
+	wechatOAuthIntentLogin      = "login"
+	wechatOAuthIntentBind       = "bind_current_user"
+	wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
+)
+
+// Upstream WeChat API endpoints. Declared as variables rather than
+// constants, presumably so tests can point them at a stub server — TODO
+// confirm against the test suite.
+var (
+	wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL    = "https://api.weixin.qq.com/sns/userinfo"
+)
+
+// wechatOAuthConfig is the per-request resolved WeChat OAuth configuration.
+type wechatOAuthConfig struct {
+	mode             string // "open" (QR web login) or "mp" (in-WeChat browser)
+	appID            string // credentials, sourced from mode-specific env vars
+	appSecret        string
+	authorizeURL     string // upstream authorize endpoint for the mode
+	scope            string // snsapi_login (open) or snsapi_userinfo (mp)
+	redirectURI      string // absolute callback URL on this service
+	frontendCallback string // frontend page that finishes the flow
+}
+
+// wechatOAuthTokenResponse mirrors the JSON body of sns/oauth2/access_token.
+// WeChat signals failure with a non-zero ErrCode in the body (often with an
+// HTTP 200 status), so callers must check ErrCode as well as the status.
+type wechatOAuthTokenResponse struct {
+	AccessToken  string `json:"access_token"`
+	ExpiresIn    int64  `json:"expires_in"`
+	RefreshToken string `json:"refresh_token"`
+	OpenID       string `json:"openid"`
+	Scope        string `json:"scope"`
+	UnionID      string `json:"unionid"` // only present for accounts bound to an open platform
+	ErrCode      int64  `json:"errcode"`
+	ErrMsg       string `json:"errmsg"`
+}
+
+// wechatOAuthUserInfoResponse mirrors the JSON body of sns/userinfo.
+// As with the token endpoint, a non-zero ErrCode indicates failure.
+type wechatOAuthUserInfoResponse struct {
+	OpenID     string `json:"openid"`
+	Nickname   string `json:"nickname"`
+	HeadImgURL string `json:"headimgurl"`
+	UnionID    string `json:"unionid"`
+	ErrCode    int64  `json:"errcode"`
+	ErrMsg     string `json:"errmsg"`
+}
+
+// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
+// browser cookies required by the rebuild pending-auth bridge.
+//
+// It resolves the mode-specific config, generates an anti-CSRF state value,
+// builds the upstream authorize URL, persists the state/redirect/intent/mode
+// cookies plus the pending browser-session cookie, and finally redirects the
+// user agent to WeChat.
+func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
+	cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	state, err := oauth.GenerateState()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+		return
+	}
+
+	// Build the authorize URL before emitting any cookies so a failure here
+	// cannot leave half-initialized oauth state in the browser.
+	authURL, err := buildWeChatAuthorizeURL(cfg, state)
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+		return
+	}
+
+	redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+	if redirectTo == "" {
+		redirectTo = wechatOAuthDefaultRedirectTo
+	}
+
+	browserSessionKey, err := generateOAuthPendingBrowserSession()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+		return
+	}
+
+	intent := normalizeWeChatOAuthIntent(c.Query("intent"))
+	secureCookie := isRequestHTTPS(c)
+	wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
+	setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+	// Drop any stale pending-session cookie from a previous flow.
+	clearOAuthPendingSessionCookie(c, secureCookie)
+
+	c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
+// and stores the result in the unified pending-auth flow.
+//
+// All failures redirect to the frontend callback page with an error code
+// rather than returning JSON, because this endpoint is hit by a browser
+// navigation from WeChat.
+func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
+	frontendCallback := wechatOAuthFrontendCallback()
+
+	// Provider-reported errors short-circuit before any cookie validation.
+	if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+		redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+		return
+	}
+
+	code := strings.TrimSpace(c.Query("code"))
+	state := strings.TrimSpace(c.Query("state"))
+	if code == "" || state == "" {
+		redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+		return
+	}
+
+	// The flow cookies are single-use: clear them on every exit path.
+	secureCookie := isRequestHTTPS(c)
+	defer func() {
+		wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+	}()
+
+	// Anti-CSRF: the state echoed by WeChat must match the cookie set at /start.
+	expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
+	if err != nil || expectedState == "" || state != expectedState {
+		redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+		return
+	}
+
+	// Redirect target is re-sanitized here; the cookie value is untrusted.
+	redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
+	redirectTo = sanitizeFrontendRedirectPath(redirectTo)
+	if redirectTo == "" {
+		redirectTo = wechatOAuthDefaultRedirectTo
+	}
+	browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+	if strings.TrimSpace(browserSessionKey) == "" {
+		redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+		return
+	}
+
+	// Mode must round-trip from /start so the same app credentials are used
+	// for the code exchange.
+	intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
+	mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
+	if err != nil || strings.TrimSpace(mode) == "" {
+		redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
+		return
+	}
+
+	cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+		return
+	}
+
+	tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
+		return
+	}
+
+	// Prefer unionid (stable across apps of the same开放平台 account) over
+	// openid as the provider subject, taking either from the userinfo or the
+	// token response.
+	unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
+	openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
+	providerSubject := firstNonEmpty(unionid, openid)
+	if providerSubject == "" {
+		redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "")
+		return
+	}
+
+	// WeChat provides no email, so a synthetic one is derived from the subject.
+	username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
+	email := wechatSyntheticEmail(providerSubject)
+	upstreamClaims := map[string]any{
+		"email":                  email,
+		"username":               username,
+		"subject":                providerSubject,
+		"openid":                 openid,
+		"unionid":                unionid,
+		"mode":                   cfg.mode,
+		"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
+		"suggested_avatar_url":   strings.TrimSpace(userInfo.HeadImgURL),
+	}
+
+	// Attempt login/registration with an empty invitation code; if one is
+	// required the error is stored on the pending session and the frontend
+	// completes registration via CompleteWeChatOAuthRegistration.
+	tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
+	if err != nil {
+		if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	// Success: the token pair is delivered through the pending session's
+	// completion payload rather than directly in this redirect.
+	if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil {
+		redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+		return
+	}
+	redirectToFrontendCallback(c, frontendCallback)
+}
+
+// completeWeChatOAuthRequest is the body of
+// POST /api/v1/auth/oauth/wechat/complete-registration.
+// AdoptDisplayName/AdoptAvatar are optional pointers; nil appears to mean
+// "keep any previously stored adoption decision" — confirm against
+// upsertPendingOAuthAdoptionDecision.
+type completeWeChatOAuthRequest struct {
+	InvitationCode   string `json:"invitation_code" binding:"required"`
+	AdoptDisplayName *bool  `json:"adopt_display_name,omitempty"`
+	AdoptAvatar      *bool  `json:"adopt_avatar,omitempty"`
+}
+
+// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
+// validating the invitation code and consuming the current pending browser session.
+// POST /api/v1/auth/oauth/wechat/complete-registration
+//
+// On success it applies any profile-adoption decision, consumes the pending
+// session (making it single-use), clears the flow cookies, and returns a
+// bearer token pair.
+func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
+	var req completeWeChatOAuthRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
+		return
+	}
+
+	// Both cookies must be present and mutually consistent; failures clear
+	// them so the browser starts a fresh flow.
+	secureCookie := isRequestHTTPS(c)
+	sessionToken, err := readOAuthPendingSessionCookie(c)
+	if err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+		return
+	}
+	browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+	if err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+		return
+	}
+	pendingSvc, err := h.pendingIdentityService()
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+	if err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// The registration context (synthetic email + username) was stored by the
+	// callback; without it the session cannot complete a registration.
+	email := strings.TrimSpace(session.ResolvedEmail)
+	username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+	if email == "" || username == "" {
+		response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+		return
+	}
+
+	// Retry login/registration, this time with the user-supplied invitation code.
+	tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	// Merge the request's adoption flags into any decision stored earlier in
+	// the flow, then apply it to the newly created user.
+	decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+		AdoptDisplayName: req.AdoptDisplayName,
+		AdoptAvatar:      req.AdoptAvatar,
+	})
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+		return
+	}
+	// Consume the session so the completion endpoint cannot be replayed.
+	if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, err)
+		return
+	}
+	clearOAuthPendingSessionCookie(c, secureCookie)
+	clearOAuthPendingBrowserCookie(c, secureCookie)
+
+	c.JSON(http.StatusOK, gin.H{
+		"access_token":  tokenPair.AccessToken,
+		"refresh_token": tokenPair.RefreshToken,
+		"expires_in":    tokenPair.ExpiresIn,
+		"token_type":    "Bearer",
+	})
+}
+
+// createWeChatPendingSession persists the callback outcome into the unified
+// pending-auth session so the frontend callback page can finish the flow.
+//
+// The completion payload carries either the issued token pair (successful
+// login) or an "invitation_required" marker; any other auth error aborts the
+// flow and is returned to the caller unchanged.
+//
+// NOTE(review): unlike the OIDC callback, no TargetUserID is recorded on the
+// successful-login path — confirm profile adoption can still resolve the
+// user later.
+func (h *AuthHandler) createWeChatPendingSession(
+	c *gin.Context,
+	intent string,
+	providerSubject string,
+	email string,
+	redirectTo string,
+	browserSessionKey string,
+	upstreamClaims map[string]any,
+	tokenPair *service.TokenPair,
+	authErr error,
+) error {
+	completionResponse := map[string]any{
+		"redirect": redirectTo,
+	}
+	if authErr != nil {
+		// Only the invitation-required error is recoverable via the pending
+		// session; everything else propagates.
+		if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
+			completionResponse["error"] = "invitation_required"
+		} else {
+			return authErr
+		}
+	} else if tokenPair != nil {
+		completionResponse["access_token"] = tokenPair.AccessToken
+		completionResponse["refresh_token"] = tokenPair.RefreshToken
+		completionResponse["expires_in"] = tokenPair.ExpiresIn
+		completionResponse["token_type"] = "Bearer"
+	}
+
+	return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+		Intent: intent,
+		Identity: service.PendingAuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     wechatOAuthProviderKey,
+			ProviderSubject: providerSubject,
+		},
+		ResolvedEmail:          email,
+		RedirectTo:             redirectTo,
+		BrowserSessionKey:      browserSessionKey,
+		UpstreamIdentityClaims: upstreamClaims,
+		CompletionResponse:     completionResponse,
+	})
+}
+
+// getWeChatOAuthConfig resolves the effective WeChat OAuth configuration for
+// the requested mode ("open" = QR web login, "mp" = in-WeChat browser).
+// Credentials come from mode-specific environment variables; the callback
+// URL is derived from the admin-configured API base URL when available,
+// otherwise from the incoming request.
+//
+// Returns OAUTH_DISABLED when the mode's credentials are not configured.
+func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
+	mode, err := resolveWeChatOAuthMode(rawMode, c)
+	if err != nil {
+		return wechatOAuthConfig{}, err
+	}
+
+	// Best-effort settings lookup: on failure we fall back to deriving the
+	// callback URL from the request itself.
+	apiBaseURL := ""
+	if h != nil && h.settingSvc != nil {
+		settings, err := h.settingSvc.GetAllSettings(ctx)
+		if err == nil && settings != nil {
+			apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+		}
+	}
+
+	cfg := wechatOAuthConfig{
+		mode:             mode,
+		redirectURI:      resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"),
+		frontendCallback: wechatOAuthFrontendCallback(),
+	}
+
+	switch mode {
+	case "mp":
+		// In-WeChat browser flow: 公众号 credentials + userinfo scope.
+		cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
+		cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
+		cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
+		cfg.scope = "snsapi_userinfo"
+	default:
+		// Desktop QR login via the open platform.
+		cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
+		cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
+		cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
+		cfg.scope = "snsapi_login"
+	}
+
+	if cfg.appID == "" || cfg.appSecret == "" {
+		return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+	}
+	if strings.TrimSpace(cfg.redirectURI) == "" {
+		return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
+	}
+
+	return cfg, nil
+}
+
+// wechatOAuthFrontendCallback returns the frontend callback URL, preferring
+// the WECHAT_OAUTH_FRONTEND_REDIRECT_URL environment variable when set.
+func wechatOAuthFrontendCallback() string {
+	if override := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); override != "" {
+		return override
+	}
+	return wechatOAuthDefaultFrontendCB
+}
+
+// resolveWeChatOAuthMode normalizes the requested oauth mode. An empty mode
+// is inferred from the user agent (in-WeChat browsers get "mp", everything
+// else "open"); only "open" and "mp" are accepted explicitly.
+func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
+	switch mode := strings.ToLower(strings.TrimSpace(rawMode)); mode {
+	case "":
+		if isWeChatBrowserRequest(c) {
+			return "mp", nil
+		}
+		return "open", nil
+	case "open", "mp":
+		return mode, nil
+	default:
+		return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
+	}
+}
+
+// isWeChatBrowserRequest reports whether the request comes from WeChat's
+// in-app browser, detected via the "micromessenger" User-Agent marker.
+func isWeChatBrowserRequest(c *gin.Context) bool {
+	if c == nil || c.Request == nil {
+		return false
+	}
+	ua := strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent")))
+	return strings.Contains(ua, "micromessenger")
+}
+
+// normalizeWeChatOAuthIntent maps a raw intent query value onto one of the
+// supported flow intents, defaulting to login for anything unrecognized.
+func normalizeWeChatOAuthIntent(raw string) string {
+	normalized := strings.ToLower(strings.TrimSpace(raw))
+	if normalized == "bind" || normalized == wechatOAuthIntentBind {
+		return wechatOAuthIntentBind
+	}
+	if normalized == "adopt" || normalized == wechatOAuthIntentAdoptEmail {
+		return wechatOAuthIntentAdoptEmail
+	}
+	return wechatOAuthIntentLogin
+}
+
+// buildWeChatAuthorizeURL assembles the upstream authorize URL for the
+// resolved config and anti-CSRF state value.
+func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
+	authorize, err := url.Parse(cfg.authorizeURL)
+	if err != nil {
+		return "", fmt.Errorf("parse authorize url: %w", err)
+	}
+	params := authorize.Query()
+	for key, value := range map[string]string{
+		"appid":         cfg.appID,
+		"redirect_uri":  cfg.redirectURI,
+		"response_type": "code",
+		"scope":         cfg.scope,
+		"state":         state,
+	} {
+		params.Set(key, value)
+	}
+	authorize.RawQuery = params.Encode()
+	// WeChat's docs require the literal "#wechat_redirect" fragment.
+	authorize.Fragment = "wechat_redirect"
+	return authorize.String(), nil
+}
+
+// resolveWeChatOAuthAbsoluteURL builds an absolute URL for callbackPath,
+// preferring the configured API base URL and falling back to the incoming
+// request's scheme and host. Returns "" when neither source yields a host.
+//
+// When the base URL already ends in /api/v1 and callbackPath starts with
+// /api/v1, the duplicate prefix is collapsed instead of doubled.
+//
+// NOTE(review): X-Forwarded-Host is trusted unconditionally in the fallback
+// path — confirm the service always runs behind a proxy that sanitizes it.
+func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
+	callbackPath = strings.TrimSpace(callbackPath)
+	if callbackPath == "" {
+		return ""
+	}
+
+	// Preferred source: admin-configured API base URL (must be absolute).
+	if raw := strings.TrimSpace(apiBaseURL); raw != "" {
+		if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+			basePath := strings.TrimRight(parsed.EscapedPath(), "/")
+			targetPath := callbackPath
+			if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
+				targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
+			} else if basePath != "" {
+				targetPath = basePath + callbackPath
+			}
+			return parsed.Scheme + "://" + parsed.Host + targetPath
+		}
+	}
+
+	// Fallback: derive scheme/host from the request itself.
+	if c == nil || c.Request == nil {
+		return ""
+	}
+	scheme := "http"
+	if isRequestHTTPS(c) {
+		scheme = "https"
+	}
+	host := strings.TrimSpace(c.Request.Host)
+	if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
+		host = forwardedHost
+	}
+	if host == "" {
+		return ""
+	}
+	return scheme + "://" + host + callbackPath
+}
+
+// fetchWeChatOAuthIdentity performs the two-step identity fetch: exchange the
+// authorization code for an access token, then load the user's profile.
+func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
+	token, err := exchangeWeChatOAuthCode(ctx, cfg, code)
+	if err != nil {
+		return nil, nil, err
+	}
+	profile, err := fetchWeChatUserInfo(ctx, token)
+	if err != nil {
+		return nil, nil, err
+	}
+	return token, profile, nil
+}
+
+// exchangeWeChatOAuthCode swaps an authorization code for a WeChat OAuth
+// access token via sns/oauth2/access_token.
+//
+// WeChat reports failures with a non-zero errcode in the JSON body (often
+// alongside HTTP 200), so both the HTTP status and errcode are checked.
+func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
+	endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
+	if err != nil {
+		return nil, fmt.Errorf("parse wechat access token url: %w", err)
+	}
+
+	query := endpoint.Query()
+	query.Set("appid", cfg.appID)
+	query.Set("secret", cfg.appSecret)
+	query.Set("code", strings.TrimSpace(code))
+	query.Set("grant_type", "authorization_code")
+	endpoint.RawQuery = query.Encode()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+	if err != nil {
+		return nil, fmt.Errorf("build wechat access token request: %w", err)
+	}
+
+	client := &http.Client{Timeout: 30 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("request wechat access token: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Bound the read so a misbehaving upstream cannot exhaust memory; the
+	// token payload is tiny in practice.
+	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	if err != nil {
+		return nil, fmt.Errorf("read wechat access token response: %w", err)
+	}
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+		return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
+	}
+
+	var tokenResp wechatOAuthTokenResponse
+	if err := json.Unmarshal(body, &tokenResp); err != nil {
+		return nil, fmt.Errorf("decode wechat access token response: %w", err)
+	}
+	if tokenResp.ErrCode != 0 {
+		return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
+	}
+	if strings.TrimSpace(tokenResp.AccessToken) == "" {
+		return nil, fmt.Errorf("wechat access token missing access_token")
+	}
+	return &tokenResp, nil
+}
+
+// fetchWeChatUserInfo loads the user's profile (nickname, avatar, unionid)
+// via sns/userinfo using the previously exchanged access token.
+//
+// WeChat reports failures with a non-zero errcode in the JSON body, so both
+// the HTTP status and errcode are checked.
+func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
+	if tokenResp == nil {
+		return nil, fmt.Errorf("wechat token response is nil")
+	}
+
+	endpoint, err := url.Parse(wechatOAuthUserInfoURL)
+	if err != nil {
+		return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
+	}
+	query := endpoint.Query()
+	query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
+	query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
+	query.Set("lang", "zh_CN")
+	endpoint.RawQuery = query.Encode()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+	if err != nil {
+		return nil, fmt.Errorf("build wechat userinfo request: %w", err)
+	}
+
+	client := &http.Client{Timeout: 30 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("request wechat userinfo: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Bound the read so a misbehaving upstream cannot exhaust memory; the
+	// profile payload is tiny in practice.
+	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	if err != nil {
+		return nil, fmt.Errorf("read wechat userinfo response: %w", err)
+	}
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+		return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
+	}
+
+	var userInfo wechatOAuthUserInfoResponse
+	if err := json.Unmarshal(body, &userInfo); err != nil {
+		return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
+	}
+	if userInfo.ErrCode != 0 {
+		return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
+	}
+	return &userInfo, nil
+}
+
+func wechatSyntheticEmail(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return ""
+ }
+ return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
+}
+
+// wechatFallbackUsername derives a username from the provider subject when
+// WeChat supplies no nickname, using a generic placeholder for an empty subject.
+func wechatFallbackUsername(subject string) string {
+	trimmed := strings.TrimSpace(subject)
+	if trimmed == "" {
+		return "wechat_user"
+	}
+	return "wechat_" + truncateFragmentValue(trimmed)
+}
+
+// wechatSetCookie writes a wechat-oauth flow cookie, scoped to the wechat
+// oauth path, HttpOnly, SameSite=Lax.
+func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+	cookie := http.Cookie{
+		Name:     name,
+		Value:    value,
+		Path:     wechatOAuthCookiePath,
+		MaxAge:   maxAgeSec,
+		HttpOnly: true,
+		Secure:   secure,
+		SameSite: http.SameSiteLaxMode,
+	}
+	http.SetCookie(c.Writer, &cookie)
+}
+
+// wechatClearCookie expires the named wechat-oauth flow cookie by issuing an
+// otherwise-identical cookie with an empty value and MaxAge -1.
+func wechatClearCookie(c *gin.Context, name string, secure bool) {
+	wechatSetCookie(c, name, "", -1, secure)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
new file mode 100644
index 00000000..1a765dcc
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -0,0 +1,411 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/base64"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+
+ gin.SetMode(gin.TestMode)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+ c.Request.Host = "api.example.com"
+
+ handler := &AuthHandler{}
+ handler.WeChatOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotEmpty(t, location)
+ require.Contains(t, location, "open.weixin.qq.com")
+ require.Contains(t, location, "appid=wx-open-app")
+ require.Contains(t, location, "scope=snsapi_login")
+
+ cookies := recorder.Result().Cookies()
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
+ require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
+}
+
+func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.Equal(t, "wechat-main", session.ProviderKey)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
+ require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
+}
+
+func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, true)
+ defer client.Close()
+
+ ctx := context.Background()
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
+ Code: "invite-1",
+ Type: service.RedeemTypeInvitation,
+ Status: service.StatusUnused,
+ }))
+
+ callbackRecorder := httptest.NewRecorder()
+ callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
+ callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ callbackReq.Host = "api.example.com"
+ callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ callbackCtx.Request = callbackReq
+
+ handler.WeChatOAuthCallback(callbackCtx)
+
+ require.Equal(t, http.StatusFound, callbackRecorder.Code)
+ require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+ sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
+
+ pendingSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"])
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
+ completeRecorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(completeRecorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, completeRecorder.Code)
+ responseData := decodeJSONBody(t, completeRecorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "WeChat Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ userRepo := &oauthPendingFlowUserRepo{client: client}
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ },
+ }, &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ })
+
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &wechatOAuthRefreshTokenCacheStub{},
+ &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ },
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ }, client
+}
+
+func encodedCookie(name, value string) *http.Cookie {
+ return &http.Cookie{
+ Name: name,
+ Value: encodeCookieValue(value),
+ Path: "/",
+ }
+}
+
+func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
+ for _, cookie := range cookies {
+ if cookie.Name == name {
+ return cookie
+ }
+ }
+ return nil
+}
+
+func decodeCookieValueForTest(t *testing.T, value string) string {
+ t.Helper()
+ raw, err := base64.RawURLEncoding.DecodeString(value)
+ require.NoError(t, err)
+ return string(raw)
+}
+
+type wechatOAuthSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type wechatOAuthRefreshTokenCacheStub struct{}
+
+func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 3659e79b..f44b3e3b 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -189,6 +189,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
index 8a83bfeb..9fdefa93 100644
--- a/backend/internal/handler/payment_webhook_handler.go
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey {
- case payment.TypeEasyPay:
+ case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody)
if err == nil {
return values.Get("out_trade_no")
diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go
index bdef1766..6f448131 100644
--- a/backend/internal/handler/payment_webhook_handler_test.go
+++ b/backend/internal/handler/payment_webhook_handler_test.go
@@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
+
+func TestExtractOutTradeNo(t *testing.T) {
+ tests := []struct {
+ name string
+ providerKey string
+ rawBody string
+ want string
+ }{
+ {
+ name: "easypay query payload",
+ providerKey: "easypay",
+ rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
+ want: "sub2_123",
+ },
+ {
+ name: "alipay query payload",
+ providerKey: "alipay",
+ rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
+ want: "sub2_456",
+ },
+ {
+ name: "unknown provider",
+ providerKey: "wxpay",
+ rawBody: "{}",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
+ })
+ }
+}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 1717b7a1..c7bc3e2a 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 2535ea5e..904341d0 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -34,10 +34,16 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type userProfileResponse struct {
+ dto.User
+ AvatarURL string `json:"avatar_url,omitempty"`
+}
+
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
@@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
- userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, dto.UserFromService(userData))
+ response.Success(c, userProfileResponseFromService(userData))
}
// ChangePassword handles changing user password
@@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
+ AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
@@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
@@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
@@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
@@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ response.Success(c, userProfileResponseFromService(updatedUser))
+}
+
+func userProfileResponseFromService(user *service.User) userProfileResponse {
+ base := dto.UserFromService(user)
+ if base == nil {
+ return userProfileResponse{}
+ }
+ return userProfileResponse{
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ }
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
new file mode 100644
index 00000000..1973f59e
--- /dev/null
+++ b/backend/internal/handler/user_handler_test.go
@@ -0,0 +1,136 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type userHandlerRepoStub struct {
+ user *service.User
+}
+
+func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
+func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
+ cloned := *user
+ s.user = &cloned
+ return nil
+}
+func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ if s.user == nil || s.user.AvatarURL == "" {
+ return nil, nil
+ }
+ return &service.UserAvatar{
+ StorageProvider: s.user.AvatarSource,
+ URL: s.user.AvatarURL,
+ ContentType: s.user.AvatarMIME,
+ ByteSize: s.user.AvatarByteSize,
+ SHA256: s.user.AvatarSHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ s.user.AvatarURL = input.URL
+ s.user.AvatarSource = input.StorageProvider
+ s.user.AvatarMIME = input.ContentType
+ s.user.AvatarByteSize = input.ByteSize
+ s.user.AvatarSHA256 = input.SHA256
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ s.user.AvatarURL = ""
+ s.user.AvatarSource = ""
+ s.user.AvatarMIME = ""
+ s.user.AvatarByteSize = 0
+ s.user.AvatarSHA256 = ""
+ return nil
+}
+func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "handler-avatar@example.com",
+ Username: "handler-avatar",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+
+ body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.UpdateProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ AvatarURL string `json:"avatar_url"`
+ Username string `json:"username"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
+ require.Equal(t, "handler-avatar", resp.Data.Username)
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 38ea9bde..36d80309 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
+ user.FieldSignupSource,
+ user.FieldLastLoginAt,
+ user.FieldLastActiveAt,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
+ SignupSource: u.SignupSource,
+ LastLoginAt: u.LastLoginAt,
+ LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
diff --git a/backend/internal/repository/auth_identity_migration_report.go b/backend/internal/repository/auth_identity_migration_report.go
new file mode 100644
index 00000000..70f298c1
--- /dev/null
+++ b/backend/internal/repository/auth_identity_migration_report.go
@@ -0,0 +1,148 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+)
+
+type AuthIdentityMigrationReport struct {
+ ID int64
+ ReportType string
+ ReportKey string
+ Details map[string]any
+ CreatedAt time.Time
+}
+
+type AuthIdentityMigrationReportQuery struct {
+ ReportType string
+ Limit int
+ Offset int
+}
+
+type AuthIdentityMigrationReportSummary struct {
+ Total int64
+ ByType map[string]int64
+}
+
+func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ limit := query.Limit
+ if limit <= 0 {
+ limit = 100
+ }
+ rows, err := exec.QueryContext(ctx, `
+SELECT id, report_type, report_key, details, created_at
+FROM auth_identity_migration_reports
+WHERE ($1 = '' OR report_type = $1)
+ORDER BY created_at DESC, id DESC
+LIMIT $2 OFFSET $3`,
+ strings.TrimSpace(query.ReportType),
+ limit,
+ query.Offset,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ reports := make([]AuthIdentityMigrationReport, 0)
+ for rows.Next() {
+ report, scanErr := scanAuthIdentityMigrationReport(rows)
+ if scanErr != nil {
+ return nil, scanErr
+ }
+ reports = append(reports, report)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return reports, nil
+}
+
+func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT id, report_type, report_key, details, created_at
+FROM auth_identity_migration_reports
+WHERE report_type = $1 AND report_key = $2
+LIMIT 1`,
+ strings.TrimSpace(reportType),
+ strings.TrimSpace(reportKey),
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, sql.ErrNoRows
+ }
+ report, err := scanAuthIdentityMigrationReport(rows)
+ if err != nil {
+ return nil, err
+ }
+ return &report, rows.Err()
+}
+
+func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT report_type, COUNT(*)
+FROM auth_identity_migration_reports
+GROUP BY report_type
+ORDER BY report_type ASC`)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ summary := &AuthIdentityMigrationReportSummary{
+ ByType: make(map[string]int64),
+ }
+ for rows.Next() {
+ var reportType string
+ var count int64
+ if err := rows.Scan(&reportType, &count); err != nil {
+ return nil, err
+ }
+ summary.ByType[reportType] = count
+ summary.Total += count
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return summary, nil
+}
+
+func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
+ var (
+ report AuthIdentityMigrationReport
+ details []byte
+ )
+ if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
+ return AuthIdentityMigrationReport{}, err
+ }
+ report.Details = map[string]any{}
+ if len(details) > 0 {
+ if err := json.Unmarshal(details, &report.Details); err != nil {
+ return AuthIdentityMigrationReport{}, err
+ }
+ }
+ return report, nil
+}
diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go
new file mode 100644
index 00000000..4ecae4a4
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo.go
@@ -0,0 +1,544 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+ "time"
+ "unsafe"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// Sentinel conflict errors returned by BindAuthIdentityToUser when an
+// identity (or identity channel) is already claimed by a different user.
+var (
+	ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
+		"AUTH_IDENTITY_OWNERSHIP_CONFLICT",
+		"auth identity already belongs to another user",
+	)
+	ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
+		"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
+		"auth identity channel already belongs to another user",
+	)
+)
+
+// ProviderGrantReason records why a provider default grant was issued.
+type ProviderGrantReason string
+
+const (
+	ProviderGrantReasonSignup    ProviderGrantReason = "signup"
+	ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
+)
+
+// AuthIdentityKey is the canonical identity triple (provider type/key plus
+// the provider-scoped subject) used to look up an auth identity.
+type AuthIdentityKey struct {
+	ProviderType    string
+	ProviderKey     string
+	ProviderSubject string
+}
+
+// AuthIdentityChannelKey identifies a channel-specific identity (e.g. a
+// per-app subject) under a provider.
+type AuthIdentityChannelKey struct {
+	ProviderType   string
+	ProviderKey    string
+	Channel        string
+	ChannelAppID   string
+	ChannelSubject string
+}
+
+// CreateAuthIdentityInput carries everything needed to create (or bind) an
+// auth identity, optionally together with one channel record.
+type CreateAuthIdentityInput struct {
+	UserID          int64
+	Canonical       AuthIdentityKey
+	Channel         *AuthIdentityChannelKey
+	Issuer          *string
+	VerifiedAt      *time.Time
+	Metadata        map[string]any
+	ChannelMetadata map[string]any
+}
+
+// BindAuthIdentityInput is identical in shape to CreateAuthIdentityInput.
+type BindAuthIdentityInput = CreateAuthIdentityInput
+
+// CreateAuthIdentityResult returns the persisted identity and, when
+// requested, its channel row.
+type CreateAuthIdentityResult struct {
+	Identity *dbent.AuthIdentity
+	Channel  *dbent.AuthIdentityChannel
+}
+
+// IdentityRef derives the canonical key of the created identity; a nil
+// receiver or missing identity yields the zero key.
+func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
+	if r == nil || r.Identity == nil {
+		return AuthIdentityKey{}
+	}
+	return AuthIdentityKey{
+		ProviderType:    r.Identity.ProviderType,
+		ProviderKey:     r.Identity.ProviderKey,
+		ProviderSubject: r.Identity.ProviderSubject,
+	}
+}
+
+// ChannelRef derives the channel key of the created channel row, or nil
+// when no channel was created.
+func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
+	if r == nil || r.Channel == nil {
+		return nil
+	}
+	return &AuthIdentityChannelKey{
+		ProviderType:   r.Channel.ProviderType,
+		ProviderKey:    r.Channel.ProviderKey,
+		Channel:        r.Channel.Channel,
+		ChannelAppID:   r.Channel.ChannelAppID,
+		ChannelSubject: r.Channel.ChannelSubject,
+	}
+}
+
+// UserAuthIdentityLookup bundles the user, identity and (optionally) the
+// channel row matched by an identity lookup.
+type UserAuthIdentityLookup struct {
+	User     *dbent.User
+	Identity *dbent.AuthIdentity
+	Channel  *dbent.AuthIdentityChannel
+}
+
+// ProviderGrantRecordInput identifies a (user, provider, reason) grant.
+type ProviderGrantRecordInput struct {
+	UserID       int64
+	ProviderType string
+	GrantReason  ProviderGrantReason
+}
+
+// IdentityAdoptionDecisionInput carries one adoption decision keyed by its
+// pending auth session; IdentityID is optional.
+type IdentityAdoptionDecisionInput struct {
+	PendingAuthSessionID int64
+	IdentityID           *int64
+	AdoptDisplayName     bool
+	AdoptAvatar          bool
+}
+
+// sqlQueryExecutor is the minimal raw-SQL surface this repository needs;
+// satisfied by *sql.DB, *sql.Tx and ent's dialect driver.
+type sqlQueryExecutor interface {
+	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+}
+
+// WithUserProfileIdentityTx runs fn inside an ent transaction. If the
+// context already carries a transaction, fn is invoked directly so nested
+// calls share the outer transaction instead of opening a new one.
+func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+	if dbent.TxFromContext(ctx) != nil {
+		return fn(ctx)
+	}
+
+	tx, err := r.client.Tx(ctx)
+	if err != nil {
+		return err
+	}
+	// Rollback after a successful Commit is a no-op; the error is discarded.
+	defer func() { _ = tx.Rollback() }()
+
+	txCtx := dbent.NewTxContext(ctx, tx)
+	if err := fn(txCtx); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// CreateAuthIdentity inserts a new canonical auth identity for a user and,
+// when input.Channel is set, a linked channel row. Key fields are trimmed
+// before persisting. Note: the two inserts are only atomic when the caller
+// supplies a transaction-bound context (e.g. via WithUserProfileIdentityTx).
+func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+	client := clientFromContext(ctx, r.client)
+
+	create := client.AuthIdentity.Create().
+		SetUserID(input.UserID).
+		SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
+		SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
+		SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
+		SetMetadata(copyMetadata(input.Metadata)).
+		SetNillableIssuer(input.Issuer).
+		SetNillableVerifiedAt(input.VerifiedAt)
+
+	identity, err := create.Save(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	var channel *dbent.AuthIdentityChannel
+	if input.Channel != nil {
+		channel, err = client.AuthIdentityChannel.Create().
+			SetIdentityID(identity.ID).
+			SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+			SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+			SetChannel(strings.TrimSpace(input.Channel.Channel)).
+			SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+			SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+			SetMetadata(copyMetadata(input.ChannelMetadata)).
+			Save(ctx)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
+}
+
+// GetUserByCanonicalIdentity resolves a user via the canonical identity
+// triple. A missing identity surfaces ent's not-found error from Only.
+func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
+	identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+			authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+			authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
+		).
+		WithUser().
+		Only(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	return &UserAuthIdentityLookup{
+		User:     identity.Edges.User,
+		Identity: identity,
+	}, nil
+}
+
+// GetUserByChannelIdentity resolves a user via a channel-level identity,
+// eager-loading the parent identity and its user in one query.
+func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
+	channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
+		Where(
+			authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+			authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+			authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
+			authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
+			authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
+		).
+		WithIdentity(func(q *dbent.AuthIdentityQuery) {
+			q.WithUser()
+		}).
+		Only(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// NOTE(review): assumes Edges.Identity is always populated after
+	// WithIdentity — a dangling channel row would panic here; confirm the
+	// schema enforces the identity FK.
+	return &UserAuthIdentityLookup{
+		User:     channel.Edges.Identity.Edges.User,
+		Identity: channel.Edges.Identity,
+		Channel:  channel,
+	}, nil
+}
+
+// BindAuthIdentityToUser idempotently binds a canonical identity (and an
+// optional channel) to a user inside a transaction. Existing rows owned by
+// the same user are updated in place; rows owned by another user cause
+// ErrAuthIdentityOwnershipConflict / ErrAuthIdentityChannelOwnershipConflict.
+// Metadata, Issuer and VerifiedAt are only overwritten when provided.
+func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+	var result *CreateAuthIdentityResult
+	err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+		client := clientFromContext(txCtx, r.client)
+		canonical := input.Canonical
+
+		// Look up the canonical identity; not-found is expected and means
+		// we create it below.
+		identity, err := client.AuthIdentity.Query().
+			Where(
+				authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
+				authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
+				authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
+			).
+			Only(txCtx)
+		if err != nil && !dbent.IsNotFound(err) {
+			return err
+		}
+		if identity != nil && identity.UserID != input.UserID {
+			return ErrAuthIdentityOwnershipConflict
+		}
+		if identity == nil {
+			identity, err = client.AuthIdentity.Create().
+				SetUserID(input.UserID).
+				SetProviderType(strings.TrimSpace(canonical.ProviderType)).
+				SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
+				SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
+				SetMetadata(copyMetadata(input.Metadata)).
+				SetNillableIssuer(input.Issuer).
+				SetNillableVerifiedAt(input.VerifiedAt).
+				Save(txCtx)
+			if err != nil {
+				return err
+			}
+		} else {
+			// Same-owner rebind: refresh only the fields the caller supplied.
+			update := client.AuthIdentity.UpdateOneID(identity.ID)
+			if input.Metadata != nil {
+				update = update.SetMetadata(copyMetadata(input.Metadata))
+			}
+			if input.Issuer != nil {
+				update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
+			}
+			if input.VerifiedAt != nil {
+				update = update.SetVerifiedAt(*input.VerifiedAt)
+			}
+			identity, err = update.Save(txCtx)
+			if err != nil {
+				return err
+			}
+		}
+
+		var channel *dbent.AuthIdentityChannel
+		if input.Channel != nil {
+			channel, err = client.AuthIdentityChannel.Query().
+				Where(
+					authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
+					authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
+					authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
+					authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
+					authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
+				).
+				WithIdentity().
+				Only(txCtx)
+			if err != nil && !dbent.IsNotFound(err) {
+				return err
+			}
+			// The channel may belong to a different canonical identity of
+			// another user — reject that before touching it.
+			if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
+				return ErrAuthIdentityChannelOwnershipConflict
+			}
+			if channel == nil {
+				channel, err = client.AuthIdentityChannel.Create().
+					SetIdentityID(identity.ID).
+					SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+					SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+					SetChannel(strings.TrimSpace(input.Channel.Channel)).
+					SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+					SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+					SetMetadata(copyMetadata(input.ChannelMetadata)).
+					Save(txCtx)
+				if err != nil {
+					return err
+				}
+			} else {
+				// Re-point the channel at the (possibly new) identity and
+				// refresh its metadata when supplied.
+				update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
+					SetIdentityID(identity.ID)
+				if input.ChannelMetadata != nil {
+					update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
+				}
+				channel, err = update.Save(txCtx)
+				if err != nil {
+					return err
+				}
+			}
+		}
+
+		result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+// RecordProviderGrant idempotently records a default grant for a
+// (user, provider, reason) triple using INSERT ... ON CONFLICT DO NOTHING.
+// It returns true only when a new row was actually inserted.
+func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
+	exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+	if exec == nil {
+		return false, fmt.Errorf("sql executor is not configured")
+	}
+
+	result, err := exec.ExecContext(ctx, `
+INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+		input.UserID,
+		strings.TrimSpace(input.ProviderType),
+		string(input.GrantReason),
+	)
+	if err != nil {
+		return false, err
+	}
+	// RowsAffected is 0 when the conflict clause suppressed the insert.
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return false, err
+	}
+	return affected > 0, nil
+}
+
+// UpsertIdentityAdoptionDecision creates or updates the adoption decision
+// keyed by its pending auth session. DecidedAt is set only on first create;
+// NOTE(review): subsequent updates keep the original DecidedAt — confirm
+// that is intentional rather than an oversight.
+func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+	client := clientFromContext(ctx, r.client)
+	current, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
+		Only(ctx)
+	if err != nil && !dbent.IsNotFound(err) {
+		return nil, err
+	}
+	now := time.Now().UTC()
+	if current == nil {
+		create := client.IdentityAdoptionDecision.Create().
+			SetPendingAuthSessionID(input.PendingAuthSessionID).
+			SetAdoptDisplayName(input.AdoptDisplayName).
+			SetAdoptAvatar(input.AdoptAvatar).
+			SetDecidedAt(now)
+		if input.IdentityID != nil {
+			create = create.SetIdentityID(*input.IdentityID)
+		}
+		return create.Save(ctx)
+	}
+
+	update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
+		SetAdoptDisplayName(input.AdoptDisplayName).
+		SetAdoptAvatar(input.AdoptAvatar)
+	if input.IdentityID != nil {
+		update = update.SetIdentityID(*input.IdentityID)
+	}
+	return update.Save(ctx)
+}
+
+// GetIdentityAdoptionDecisionByPendingAuthSessionID loads the single
+// decision row for a pending auth session (ent not-found error if absent).
+func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
+	return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
+		Only(ctx)
+}
+
+// UpdateUserLastLoginAt stores the given login timestamp on the user row.
+func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
+	_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+		SetLastLoginAt(loginAt).
+		Save(ctx)
+	return err
+}
+
+// UpdateUserLastActiveAt stores the given activity timestamp on the user row.
+func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+	_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+		SetLastActiveAt(activeAt).
+		Save(ctx)
+	return err
+}
+
+// GetUserAvatar loads a user's avatar row. It returns (nil, nil) when no
+// avatar exists — callers such as attachUserAvatar rely on that contract.
+func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+	exec, err := r.userProfileIdentitySQL(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	rows, err := exec.QueryContext(ctx, `
+SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
+FROM user_avatars
+WHERE user_id = $1`, userID)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	// No row: distinguish "absent" (nil, nil) from an iteration error.
+	if !rows.Next() {
+		return nil, rows.Err()
+	}
+
+	var avatar service.UserAvatar
+	if err := rows.Scan(
+		&avatar.StorageProvider,
+		&avatar.StorageKey,
+		&avatar.URL,
+		&avatar.ContentType,
+		&avatar.ByteSize,
+		&avatar.SHA256,
+	); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return &avatar, nil
+}
+
+// UpsertUserAvatar inserts or replaces a user's single avatar row (one row
+// per user, enforced by ON CONFLICT (user_id)). All string fields are
+// trimmed; the returned value echoes the trimmed input rather than
+// re-reading the row.
+func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+	exec, err := r.userProfileIdentitySQL(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = exec.ExecContext(ctx, `
+INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
+ON CONFLICT (user_id) DO UPDATE SET
+	storage_provider = EXCLUDED.storage_provider,
+	storage_key = EXCLUDED.storage_key,
+	url = EXCLUDED.url,
+	content_type = EXCLUDED.content_type,
+	byte_size = EXCLUDED.byte_size,
+	sha256 = EXCLUDED.sha256,
+	updated_at = NOW()`,
+		userID,
+		strings.TrimSpace(input.StorageProvider),
+		strings.TrimSpace(input.StorageKey),
+		strings.TrimSpace(input.URL),
+		strings.TrimSpace(input.ContentType),
+		input.ByteSize,
+		strings.TrimSpace(input.SHA256),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return &service.UserAvatar{
+		StorageProvider: strings.TrimSpace(input.StorageProvider),
+		StorageKey:      strings.TrimSpace(input.StorageKey),
+		URL:             strings.TrimSpace(input.URL),
+		ContentType:     strings.TrimSpace(input.ContentType),
+		ByteSize:        input.ByteSize,
+		SHA256:          strings.TrimSpace(input.SHA256),
+	}, nil
+}
+
+// DeleteUserAvatar removes a user's avatar row; deleting a non-existent
+// avatar is not an error.
+func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
+	exec, err := r.userProfileIdentitySQL(ctx)
+	if err != nil {
+		return err
+	}
+	_, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
+	return err
+}
+
+// attachUserAvatar copies a user's stored avatar fields onto the service
+// user. A nil user or a missing avatar (GetUserAvatar returns nil, nil)
+// is a no-op.
+func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error {
+	if user == nil {
+		return nil
+	}
+
+	avatar, err := r.GetUserAvatar(ctx, user.ID)
+	if err != nil {
+		return err
+	}
+	if avatar == nil {
+		return nil
+	}
+
+	user.AvatarURL = avatar.URL
+	user.AvatarSource = avatar.StorageProvider
+	user.AvatarMIME = avatar.ContentType
+	user.AvatarByteSize = avatar.ByteSize
+	user.AvatarSHA256 = avatar.SHA256
+	return nil
+}
+
+// copyMetadata returns a shallow copy of the metadata map. It always
+// returns a non-nil map (nil/empty input yields an empty map) so persisted
+// JSON metadata is never null.
+func copyMetadata(in map[string]any) map[string]any {
+	if len(in) == 0 {
+		return map[string]any{}
+	}
+	out := make(map[string]any, len(in))
+	for k, v := range in {
+		out[k] = v
+	}
+	return out
+}
+
+// txAwareSQLExecutor picks the raw-SQL executor for the current context:
+// the driver of an in-flight ent transaction when one is present, otherwise
+// the configured fallback, otherwise the base client's driver.
+// NOTE(review): if extracting the tx driver fails, this silently falls back
+// to a non-transactional executor, so raw SQL would escape the transaction —
+// confirm that degradation is acceptable.
+func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
+			return exec
+		}
+	}
+	if fallback != nil {
+		return fallback
+	}
+	return sqlExecutorFromEntClient(client)
+}
+
+// userProfileIdentitySQL wraps txAwareSQLExecutor, converting a missing
+// executor into an explicit error for callers.
+func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
+	exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+	if exec == nil {
+		return nil, fmt.Errorf("sql executor is not configured")
+	}
+	return exec, nil
+}
+
+// sqlExecutorFromEntClient reaches into ent's unexported config.driver
+// field via reflect + unsafe to expose the underlying dialect driver as a
+// raw SQL executor. WARNING: this depends on ent's internal struct layout
+// and may break on ent upgrades; it returns nil when the field is missing
+// or the driver does not implement sqlQueryExecutor.
+func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
+	if client == nil {
+		return nil
+	}
+
+	clientValue := reflect.ValueOf(client).Elem()
+	configValue := clientValue.FieldByName("config")
+	driverValue := configValue.FieldByName("driver")
+	if !driverValue.IsValid() {
+		return nil
+	}
+
+	// NewAt + UnsafeAddr bypasses the unexported-field read restriction.
+	driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
+	exec, ok := driver.(sqlQueryExecutor)
+	if !ok {
+		return nil
+	}
+	return exec
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
new file mode 100644
index 00000000..19022ec1
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -0,0 +1,428 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/suite"
+)
+
+// UserProfileIdentityRepoSuite is the integration suite (build tag
+// "integration") exercising the user profile/identity repository against a
+// real database.
+type UserProfileIdentityRepoSuite struct {
+	suite.Suite
+	ctx    context.Context
+	client *dbent.Client
+	repo   *userRepository
+}
+
+// TestUserProfileIdentityRepoSuite is the testify entry point for the suite.
+func TestUserProfileIdentityRepoSuite(t *testing.T) {
+	suite.Run(t, new(UserProfileIdentityRepoSuite))
+}
+
+// SetupTest rebuilds the repo under test and truncates every identity/
+// avatar/grant table so each test starts from a clean slate.
+func (s *UserProfileIdentityRepoSuite) SetupTest() {
+	s.ctx = context.Background()
+	s.client = testEntClient(s.T())
+	s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
+
+	_, err := integrationDB.ExecContext(s.ctx, `
+TRUNCATE TABLE
+	identity_adoption_decisions,
+	auth_identity_channels,
+	auth_identities,
+	pending_auth_sessions,
+	auth_identity_migration_reports,
+	user_provider_default_grants,
+	user_avatars
+RESTART IDENTITY`)
+	s.Require().NoError(err)
+}
+
+// mustCreateUser inserts an active user with a unique timestamped email
+// and fails the test on error.
+func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
+	s.T().Helper()
+
+	user, err := s.client.User.Create().
+		SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
+		SetPasswordHash("test-password-hash").
+		SetRole("user").
+		SetStatus("active").
+		Save(s.ctx)
+	s.Require().NoError(err)
+	return user
+}
+
+// mustCreatePendingAuthSession inserts a 15-minute pending auth session for
+// the given identity key and fails the test on error.
+func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
+	s.T().Helper()
+
+	session, err := s.client.PendingAuthSession.Create().
+		SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
+		SetIntent("bind_current_user").
+		SetProviderType(key.ProviderType).
+		SetProviderKey(key.ProviderKey).
+		SetProviderSubject(key.ProviderSubject).
+		SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+		SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
+		SetLocalFlowState(map[string]any{"step": "pending"}).
+		Save(s.ctx)
+	s.Require().NoError(err)
+	return session
+}
+
+// Creates an identity with a channel, then verifies both the canonical and
+// channel lookup paths resolve to the same user and rows.
+func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
+	user := s.mustCreateUser("canonical-channel")
+
+	verifiedAt := time.Now().UTC().Truncate(time.Second)
+	created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-open",
+			ProviderSubject: "union-123",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "wechat",
+			ProviderKey:    "wechat-open",
+			Channel:        "mp",
+			ChannelAppID:   "wx-app",
+			ChannelSubject: "openid-123",
+		},
+		Issuer:          stringPtr("https://issuer.example"),
+		VerifiedAt:      &verifiedAt,
+		Metadata:        map[string]any{"unionid": "union-123"},
+		ChannelMetadata: map[string]any{"openid": "openid-123"},
+	})
+	s.Require().NoError(err)
+	s.Require().NotNil(created.Identity)
+	s.Require().NotNil(created.Channel)
+
+	canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
+	s.Require().NoError(err)
+	s.Require().Equal(user.ID, canonical.User.ID)
+	s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
+	s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
+
+	channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
+	s.Require().NoError(err)
+	s.Require().Equal(user.ID, channel.User.ID)
+	s.Require().Equal(created.Identity.ID, channel.Identity.ID)
+	s.Require().Equal(created.Channel.ID, channel.Channel.ID)
+}
+
+// Verifies bind is idempotent for the same owner (rows reused, metadata
+// refreshed) and that identity/channel rows owned by another user raise
+// the respective ownership-conflict errors.
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
+	owner := s.mustCreateUser("owner")
+	other := s.mustCreateUser("other")
+
+	first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: owner.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-1",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "linuxdo-web",
+			ChannelSubject: "subject-1",
+		},
+		Metadata:        map[string]any{"username": "first"},
+		ChannelMetadata: map[string]any{"scope": "read"},
+	})
+	s.Require().NoError(err)
+
+	second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: owner.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-1",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "linuxdo-web",
+			ChannelSubject: "subject-1",
+		},
+		Metadata:        map[string]any{"username": "second"},
+		ChannelMetadata: map[string]any{"scope": "write"},
+	})
+	s.Require().NoError(err)
+	s.Require().Equal(first.Identity.ID, second.Identity.ID)
+	s.Require().Equal(first.Channel.ID, second.Channel.ID)
+	s.Require().Equal("second", second.Identity.Metadata["username"])
+	s.Require().Equal("write", second.Channel.Metadata["scope"])
+
+	// Same canonical subject, different user: identity conflict.
+	_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: other.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-1",
+		},
+	})
+	s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
+
+	// New canonical subject but a channel already owned elsewhere: channel conflict.
+	_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: other.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-2",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "linuxdo-web",
+			ChannelSubject: "subject-1",
+		},
+	})
+	s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
+}
+
+// Verifies that an error returned from the tx callback rolls back both the
+// ent-created identity and the raw-SQL provider grant.
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
+	user := s.mustCreateUser("tx-rollback")
+	expectedErr := errors.New("rollback")
+
+	err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+		_, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
+			UserID: user.ID,
+			Canonical: AuthIdentityKey{
+				ProviderType:    "oidc",
+				ProviderKey:     "https://issuer.example",
+				ProviderSubject: "subject-rollback",
+			},
+		})
+		s.Require().NoError(err)
+
+		inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
+			UserID:       user.ID,
+			ProviderType: "oidc",
+			GrantReason:  ProviderGrantReasonFirstBind,
+		})
+		s.Require().NoError(err)
+		s.Require().True(inserted)
+		return expectedErr
+	})
+	s.Require().ErrorIs(err, expectedErr)
+
+	_, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
+		ProviderType:    "oidc",
+		ProviderKey:     "https://issuer.example",
+		ProviderSubject: "subject-rollback",
+	})
+	s.Require().True(dbent.IsNotFound(err))
+
+	var count int
+	s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
+		user.ID,
+		"oidc",
+		string(ProviderGrantReasonFirstBind),
+	).Scan(&count))
+	s.Require().Zero(count)
+}
+
+// Verifies a grant inserts once per (user, provider, reason): the repeat is
+// a no-op, while a different reason inserts a second row.
+func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
+	user := s.mustCreateUser("grant")
+
+	inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+		UserID:       user.ID,
+		ProviderType: "wechat",
+		GrantReason:  ProviderGrantReasonFirstBind,
+	})
+	s.Require().NoError(err)
+	s.Require().True(inserted)
+
+	inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+		UserID:       user.ID,
+		ProviderType: "wechat",
+		GrantReason:  ProviderGrantReasonFirstBind,
+	})
+	s.Require().NoError(err)
+	s.Require().False(inserted)
+
+	inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+		UserID:       user.ID,
+		ProviderType: "wechat",
+		GrantReason:  ProviderGrantReasonSignup,
+	})
+	s.Require().NoError(err)
+	s.Require().True(inserted)
+
+	var count int
+	s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2`,
+		user.ID,
+		"wechat",
+	).Scan(&count))
+	s.Require().Equal(2, count)
+}
+
+// Verifies the adoption-decision upsert: first call creates the row without
+// an identity link, the second updates the same row in place and links the
+// identity, and the lookup by session returns the updated row.
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
+	user := s.mustCreateUser("adoption")
+	identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-open",
+			ProviderSubject: "union-adoption",
+		},
+	})
+	s.Require().NoError(err)
+
+	session := s.mustCreatePendingAuthSession(identity.IdentityRef())
+
+	first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+		PendingAuthSessionID: session.ID,
+		AdoptDisplayName:     true,
+		AdoptAvatar:          false,
+	})
+	s.Require().NoError(err)
+	s.Require().True(first.AdoptDisplayName)
+	s.Require().False(first.AdoptAvatar)
+	s.Require().Nil(first.IdentityID)
+
+	second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+		PendingAuthSessionID: session.ID,
+		IdentityID:           &identity.Identity.ID,
+		AdoptDisplayName:     true,
+		AdoptAvatar:          true,
+	})
+	s.Require().NoError(err)
+	s.Require().Equal(first.ID, second.ID)
+	s.Require().NotNil(second.IdentityID)
+	s.Require().Equal(identity.Identity.ID, *second.IdentityID)
+	s.Require().True(second.AdoptAvatar)
+
+	loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
+	s.Require().NoError(err)
+	s.Require().Equal(second.ID, loaded.ID)
+	s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
+}
+
+// Exercises the avatar round trip: insert, read, upsert-replace (unset
+// fields reset to zero values), delete, and the absent case (nil, nil).
+func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
+	user := s.mustCreateUser("avatar")
+
+	inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+		StorageProvider: "inline",
+		URL:             "data:image/png;base64,QUJD",
+		ContentType:     "image/png",
+		ByteSize:        3,
+		SHA256:          "902fbdd2b1df0c4f70b4a5d23525e932",
+	})
+	s.Require().NoError(err)
+	s.Require().Equal("inline", inlineAvatar.StorageProvider)
+	s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
+
+	loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().NotNil(loadedAvatar)
+	s.Require().Equal("image/png", loadedAvatar.ContentType)
+	s.Require().Equal(3, loadedAvatar.ByteSize)
+
+	_, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+		StorageProvider: "remote_url",
+		URL:             "https://cdn.example.com/avatar.png",
+	})
+	s.Require().NoError(err)
+
+	loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().NotNil(loadedAvatar)
+	s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
+	s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
+	s.Require().Zero(loadedAvatar.ByteSize)
+
+	s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
+	loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().Nil(loadedAvatar)
+}
+
+// Seeds migration reports via raw SQL, then verifies the summarize, list
+// (newest first) and single-get helpers, including JSON details decoding.
+func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() {
+	_, err := integrationDB.ExecContext(s.ctx, `
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
+VALUES
+	('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'),
+	('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'),
+	('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`)
+	s.Require().NoError(err)
+
+	summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx)
+	s.Require().NoError(err)
+	s.Require().Equal(int64(3), summary.Total)
+	s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"])
+	s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
+
+	reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{
+		ReportType: "wechat_openid_only_requires_remediation",
+		Limit:      10,
+	})
+	s.Require().NoError(err)
+	s.Require().Len(reports, 2)
+	s.Require().Equal("u-2", reports[0].ReportKey)
+	// JSON numbers decode as float64 in a map[string]any.
+	s.Require().Equal(float64(2), reports[0].Details["user_id"])
+
+	report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3")
+	s.Require().NoError(err)
+	s.Require().Equal("u-3", report.ReportKey)
+	s.Require().Equal(float64(3), report.Details["user_id"])
+}
+
+// Verifies the two timestamp setters write to their own users columns,
+// checked via raw SQL rather than the repository read path.
+func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
+	user := s.mustCreateUser("activity")
+	loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+	activeAt := loginAt.Add(5 * time.Minute)
+
+	s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
+	s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
+
+	var storedLoginAt sqlNullTime
+	var storedActiveAt sqlNullTime
+	s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT last_login_at, last_active_at
+FROM users
+WHERE id = $1`,
+		user.ID,
+	).Scan(&storedLoginAt, &storedActiveAt))
+	s.Require().True(storedLoginAt.Valid)
+	s.Require().True(storedActiveAt.Valid)
+	s.Require().True(storedLoginAt.Time.Equal(loginAt))
+	s.Require().True(storedActiveAt.Time.Equal(activeAt))
+}
+
+// sqlNullTime is a minimal nullable-time scanner for raw row checks.
+// NOTE(review): database/sql's sql.NullTime provides the same behavior.
+type sqlNullTime struct {
+	Time  time.Time
+	Valid bool
+}
+
+// Scan implements sql.Scanner: time.Time marks the value valid, NULL marks
+// it invalid, anything else is an error.
+func (t *sqlNullTime) Scan(value any) error {
+	switch v := value.(type) {
+	case time.Time:
+		t.Time = v
+		t.Valid = true
+		return nil
+	case nil:
+		t.Time = time.Time{}
+		t.Valid = false
+		return nil
+	default:
+		return fmt.Errorf("unsupported scan type %T", value)
+	}
+}
+
+// stringPtr returns a pointer to v, for optional-string test inputs.
+func stringPtr(v string) *string {
+	return &v
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 913e1c40..0c607ecc 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
+ SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
+ SetNillableLastLoginAt(userIn.LastLoginAt).
+ SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
+ if userIn.SignupSource != "" {
+ updateOp = updateOp.SetSignupSource(userIn.SignupSource)
+ }
+ if userIn.LastLoginAt != nil {
+ updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
+ }
+ if userIn.LastActiveAt != nil {
+ updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
+ }
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
@@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
var field string
defaultField := true
+ nullsLastField := false
switch sortBy {
case "email":
field = dbuser.FieldEmail
@@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
+ case "last_login_at":
+ field = dbuser.FieldLastLoginAt
+ defaultField = false
+ nullsLastField = true
+ case "last_active_at":
+ field = dbuser.FieldLastActiveAt
+ defaultField = false
+ nullsLastField = true
default:
field = dbuser.FieldID
}
@@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
+ dbent.Asc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
+ dbent.Desc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
@@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
+ dst.SignupSource = src.SignupSource
+ dst.LastLoginAt = src.LastLoginAt
+ dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
+// userSignupSourceOrDefault trims the given signup source and falls back
+// to "email" when the result is empty, so newly created users always carry
+// a non-empty signup source.
+func userSignupSourceOrDefault(signupSource string) string {
+ signupSource = strings.TrimSpace(signupSource)
+ if signupSource == "" {
+ return "email"
+ }
+ return signupSource
+}
+
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index ab84b0e9..8abef45a 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -4,6 +4,7 @@ package repository
import (
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID)
}
+func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
+ lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created := s.mustCreateUser(&service.User{
+ Email: "identity-meta@example.com",
+ SignupSource: "github",
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ })
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("github", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
+ created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
+ lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created.SignupSource = "oidc"
+ created.LastLoginAt = &lastLoginAt
+ created.LastActiveAt = &lastActiveAt
+
+ s.Require().NoError(s.repo.Update(s.ctx, created))
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("oidc", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond)
+ newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-login@example.com"})
+ s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older})
+ s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_login_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("newer-login@example.com", users[0].Email)
+ s.Require().Equal("older-login@example.com", users[1].Email)
+ s.Require().Equal("nil-login@example.com", users[2].Email)
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
+ earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
+ later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
+ s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
+ s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_active_at",
+ SortOrder: "asc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("earlier-active@example.com", users[0].Email)
+ s.Require().Equal("later-active@example.com", users[1].Email)
+ s.Require().Equal("nil-active@example.com", users[2].Email)
+}
+
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index b686b986..e903898f 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
- service.SettingKeyOIDCConnectUsePKCE: "false",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
@@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
- "oidc_connect_use_pkce": false,
+ "oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index c143b030..911a4064 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -64,12 +64,26 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
+ auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
+ auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
+ auth.POST("/oauth/pending/exchange",
+ rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.ExchangePendingOAuthCompletion,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/wechat/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteWeChatOAuthRegistration,
+ )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index 419ddbc3..b802a9c2 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index fbc856cf..323286b0 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
+func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ panic("unexpected GetUserAvatar call")
+}
+
+func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go
new file mode 100644
index 00000000..b7e86e12
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service.go
@@ -0,0 +1,326 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+// Sentinel errors for the pending-auth flow. Session-level errors cover the
+// browser-session token path; code-level errors cover the one-time completion
+// code path; the browser-mismatch error is shared by both.
+var (
+ ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
+ ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
+ ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
+ ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
+ ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
+ ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
+ ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
+)
+
+// Default lifetimes: a pending session lives 15 minutes; a completion code
+// issued against it lives 5 minutes. Both can be overridden by the caller.
+const (
+ defaultPendingAuthTTL = 15 * time.Minute
+ defaultPendingAuthCompletionTTL = 5 * time.Minute
+)
+
+// PendingAuthIdentityKey identifies an upstream identity by provider type,
+// provider key, and the provider-scoped subject.
+type PendingAuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+// CreatePendingAuthSessionInput carries everything needed to persist a
+// pending auth session. SessionToken and ExpiresAt are optional: a blank
+// token is replaced with a random one and a zero expiry defaults to
+// defaultPendingAuthTTL from now.
+type CreatePendingAuthSessionInput struct {
+ SessionToken string
+ Intent string
+ Identity PendingAuthIdentityKey
+ TargetUserID *int64
+ RedirectTo string
+ ResolvedEmail string
+ RegistrationPasswordHash string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ LocalFlowState map[string]any
+ ExpiresAt time.Time
+}
+
+// IssuePendingAuthCompletionCodeInput requests a one-time completion code
+// for an existing pending session. TTL <= 0 selects the default.
+type IssuePendingAuthCompletionCodeInput struct {
+ PendingAuthSessionID int64
+ BrowserSessionKey string
+ TTL time.Duration
+}
+
+// IssuePendingAuthCompletionCodeResult returns the raw (unhashed) code and
+// its absolute expiry. Only the hash is stored server-side.
+type IssuePendingAuthCompletionCodeResult struct {
+ Code string
+ ExpiresAt time.Time
+}
+
+// PendingIdentityAdoptionDecisionInput records whether the user chose to
+// adopt the upstream display name and/or avatar for a pending session.
+type PendingIdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+// AuthPendingIdentityService persists and consumes pending auth sessions,
+// completion codes, and identity-adoption decisions via the ent client.
+type AuthPendingIdentityService struct {
+ entClient *dbent.Client
+}
+
+// NewAuthPendingIdentityService builds the service around an ent client.
+// A nil client is tolerated; every method then fails with a configuration
+// error rather than panicking.
+func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
+ return &AuthPendingIdentityService{entClient: entClient}
+}
+
+// CreatePendingSession persists a new pending auth session. All string
+// inputs are trimmed; a blank SessionToken is replaced with a random
+// 24-byte hex token and a zero ExpiresAt defaults to now+15m (UTC).
+// Claim and flow-state maps are shallow-copied before storage.
+func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken := strings.TrimSpace(input.SessionToken)
+ if sessionToken == "" {
+ var err error
+ sessionToken, err = randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ expiresAt := input.ExpiresAt.UTC()
+ if expiresAt.IsZero() {
+ expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
+ }
+
+ create := s.entClient.PendingAuthSession.Create().
+ SetSessionToken(sessionToken).
+ SetIntent(strings.TrimSpace(input.Intent)).
+ SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
+ SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
+ SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
+ SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
+ SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
+ SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
+ SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
+ SetExpiresAt(expiresAt)
+ if input.TargetUserID != nil {
+ // Only set the FK when a target user was supplied (bind flows).
+ create = create.SetTargetUserID(*input.TargetUserID)
+ }
+ return create.Save(ctx)
+}
+
+// IssueCompletionCode mints a one-time completion code for the given pending
+// session and stores only its SHA-256 hash with an expiry (TTL <= 0 selects
+// the 5-minute default). The raw code is returned once to the caller.
+// A non-blank BrowserSessionKey rebinds the session to that browser.
+//
+// NOTE(review): this issues a code without checking whether the session is
+// already expired or consumed; consumption will still fail later, but
+// confirm that issuing against a dead session is intended.
+func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+
+ code, err := randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ ttl := input.TTL
+ if ttl <= 0 {
+ ttl = defaultPendingAuthCompletionTTL
+ }
+ expiresAt := time.Now().UTC().Add(ttl)
+
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeHash(hashPendingAuthCode(code)).
+ SetCompletionCodeExpiresAt(expiresAt)
+ if strings.TrimSpace(input.BrowserSessionKey) != "" {
+ update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
+ }
+ if _, err := update.Save(ctx); err != nil {
+ return nil, err
+ }
+
+ return &IssuePendingAuthCompletionCodeResult{
+ Code: code,
+ ExpiresAt: expiresAt,
+ }, nil
+}
+
+// ConsumeCompletionCode looks up the session by the SHA-256 hash of rawCode
+// and consumes it. An unknown hash yields ErrPendingAuthCodeInvalid (which
+// also covers already-consumed codes, since consumption clears the hash);
+// expiry and browser-binding checks happen inside consumeSession.
+func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthCodeInvalid
+ }
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
+}
+
+// ConsumeBrowserSession resolves the session by its opaque session token and
+// consumes it, mapping expiry/consumed states to the session-level errors.
+func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+}
+
+// GetBrowserSession is the read-only counterpart of ConsumeBrowserSession:
+// it resolves the session by token and validates state (not consumed, not
+// expired, browser key matches) without marking it consumed.
+func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+ if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
+ return nil, err
+ }
+ return session, nil
+}
+
+// getBrowserSession fetches a pending session by its trimmed session token.
+// Blank tokens and missing rows both map to ErrPendingAuthSessionNotFound;
+// no state validation is performed here.
+func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken = strings.TrimSpace(sessionToken)
+ if sessionToken == "" {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+ return session, nil
+}
+
+// consumeSession validates the session (using the caller-supplied expired/
+// consumed errors so code- and session-paths report distinct codes), then
+// marks it consumed and clears the completion-code hash and its expiry so
+// the code can never be replayed.
+//
+// NOTE(review): the validate-then-update sequence is not atomic; two
+// concurrent consumers could both pass validation. Confirm whether a
+// conditional update (consumed_at IS NULL) is needed here.
+func (s *AuthPendingIdentityService) consumeSession(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ browserSessionKey string,
+ expiredErr error,
+ consumedErr error,
+) (*dbent.PendingAuthSession, error) {
+ if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+
+ now := time.Now().UTC()
+ updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetConsumedAt(now).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt().
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return updated, nil
+}
+
+// validatePendingSessionState checks, in order: consumed, session expiry,
+// completion-code expiry, then browser binding. The browser check only
+// applies when the session has a stored key; a session without one accepts
+// any caller key.
+func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
+ if session == nil {
+ return ErrPendingAuthSessionNotFound
+ }
+
+ now := time.Now().UTC()
+ if session.ConsumedAt != nil {
+ return consumedErr
+ }
+ if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
+ return expiredErr
+ }
+ if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
+ return expiredErr
+ }
+ if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return ErrPendingAuthBrowserMismatch
+ }
+ return nil
+}
+
+// UpsertAdoptionDecision creates or updates the single adoption decision
+// attached to a pending session.
+//
+// NOTE(review): this is a read-then-write upsert, so concurrent calls can
+// race to create duplicates unless the pending_auth_session_id column is
+// unique; consider ent's OnConflict upsert. Also note that a nil IdentityID
+// on update leaves any previously set identity in place (it never clears).
+func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ existing, err := s.entClient.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if existing == nil {
+ create := s.entClient.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+ return create.Save(ctx)
+ }
+
+ update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar)
+ if input.IdentityID != nil {
+ update = update.SetIdentityID(*input.IdentityID)
+ }
+ return update.Save(ctx)
+}
+
+// copyPendingMap returns a shallow copy of in. Empty or nil input yields a
+// non-nil empty map — presumably so the JSON column stores {} rather than
+// NULL; nested values are shared, not deep-copied.
+func copyPendingMap(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+// randomOpaqueToken returns byteLen bytes from crypto/rand encoded as a
+// lowercase hex string (2*byteLen characters). Non-positive lengths are
+// clamped to 16 bytes.
+func randomOpaqueToken(byteLen int) (string, error) {
+ if byteLen <= 0 {
+ byteLen = 16
+ }
+ buf := make([]byte, byteLen)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
+// hashPendingAuthCode returns the hex-encoded SHA-256 of code; only this
+// digest — never the raw code — is persisted.
+func hashPendingAuthCode(code string) string {
+ sum := sha256.Sum256([]byte(code))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go
new file mode 100644
index 00000000..c69ebfd2
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service_test.go
@@ -0,0 +1,224 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+// newAuthPendingIdentityServiceTestClient builds the service on top of a
+// shared in-memory SQLite database (foreign keys enabled) with schema
+// created by enttest; both the DB and client are closed via t.Cleanup.
+func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return NewAuthPendingIdentityService(client), client
+}
+
+// Verifies CreatePendingSession persists intent, identity, target user, and
+// keeps upstream claims separate from local flow state, while generating a
+// session token when none is supplied.
+func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("pending-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ TargetUserID: &targetUser.ID,
+ RedirectTo: "/profile",
+ ResolvedEmail: "user@example.com",
+ BrowserSessionKey: "browser-1",
+ UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
+ LocalFlowState: map[string]any{"step": "email_required"},
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, session.SessionToken)
+ require.Equal(t, "bind_current_user", session.Intent)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, targetUser.ID, *session.TargetUserID)
+ require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
+ require.Equal(t, "email_required", session.LocalFlowState["step"])
+}
+
+// Verifies a completion code rejects a mismatched browser key, succeeds for
+// the bound browser, and cannot be redeemed a second time (hash is cleared,
+// so a replay reports "invalid", not "consumed").
+func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expected",
+ UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
+ LocalFlowState: map[string]any{"step": "pending"},
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expected",
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, issued.Code)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+ require.Empty(t, consumed.CompletionCodeHash)
+ require.Nil(t, consumed.CompletionCodeExpiresAt)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
+}
+
+// Verifies an issued code is rejected with ErrPendingAuthCodeExpired once
+// its expiry is forced into the past (backdated directly via the ent client
+// to avoid sleeping in the test).
+func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expired",
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expired",
+ TTL: time.Second,
+ })
+ require.NoError(t, err)
+
+ _, err = client.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
+ require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
+}
+
+// Verifies UpsertAdoptionDecision creates a row on first call and updates
+// the same row (same ID) on the second, including late-binding the identity.
+func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ require.NoError(t, err)
+
+ first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.True(t, first.AdoptDisplayName)
+ require.False(t, first.AdoptAvatar)
+ require.Nil(t, first.IdentityID)
+
+ second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+ require.NotNil(t, second.IdentityID)
+ require.Equal(t, identity.ID, *second.IdentityID)
+ require.True(t, second.AdoptAvatar)
+}
+
+// Verifies the session-token consume path: wrong browser key is rejected,
+// the bound browser consumes once, and a replay reports "consumed".
+func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "subject-session-token",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "token",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index fd28cd42..962009ce 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -13,6 +13,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -106,6 +107,13 @@ func NewAuthService(
}
}
+// EntClient exposes the service's underlying ent client (nil-safe on a nil
+// receiver) so sibling services can be wired without duplicating the handle.
+func (s *AuthService) EntClient() *dbent.Client {
+ if s == nil {
+ return nil
+ }
+ return s.entClient
+}
+
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
+ s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
+ s.touchUserLogin(ctx, user.ID)
// 生成JWT token
token, err := s.GenerateToken(user)
@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
+ s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
+ s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user)
if err != nil {
@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
+ s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
+ s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
+ s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
-// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
-const pendingOAuthTokenTTL = 10 * time.Minute
-
-// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
-const pendingOAuthPurpose = "pending_oauth_registration"
-
-type pendingOAuthClaims struct {
- Email string `json:"email"`
- Username string `json:"username"`
- Purpose string `json:"purpose"`
- jwt.RegisteredClaims
-}
-
-// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
-// while waiting for the user to supply an invitation code.
-func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: email,
- Username: username,
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString([]byte(s.cfg.JWT.Secret))
-}
-
-// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
-// Returns ErrInvalidToken when the token is invalid or expired.
-func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
- if len(tokenStr) > maxTokenLength {
- return "", "", ErrInvalidToken
- }
- parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
- token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
- if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
- }
- return []byte(s.cfg.JWT.Secret), nil
- })
- if parseErr != nil {
- return "", "", ErrInvalidToken
- }
- claims, ok := token.Claims.(*pendingOAuthClaims)
- if !ok || !token.Valid {
- return "", "", ErrInvalidToken
- }
- if claims.Purpose != pendingOAuthPurpose {
- return "", "", ErrInvalidToken
- }
- return claims.Email, claims.Username, nil
-}
-
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
}
}
+func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
+ if user == nil || user.ID <= 0 {
+ return
+ }
+
+	if signupSource = strings.TrimSpace(signupSource); signupSource == "" {
+ signupSource = "email"
+ }
+ s.updateUserSignupSource(ctx, user.ID, signupSource)
+
+ if signupSource == "email" {
+ s.ensureEmailAuthIdentity(ctx, user)
+ }
+ if touchLogin {
+ s.touchUserLogin(ctx, user.ID)
+ }
+}
+
+func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ if strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetSignupSource(signupSource).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
+ }
+}
+
+func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ now := time.Now().UTC()
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetLastLoginAt(now).
+ SetLastActiveAt(now).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
+ }
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return
+ }
+
+ if err := s.entClient.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(email).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{
+ "source": "auth_service_dual_write",
+ }).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ }
+}
+
+func inferLegacySignupSource(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ switch {
+ case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
+ return "linuxdo"
+ case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
+ return "oidc"
+ case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
+ return "wechat"
+ default:
+ return "email"
+ }
+}
+
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
- strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
+ strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
new file mode 100644
index 00000000..5bd2b25d
--- /dev/null
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -0,0 +1,153 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type authIdentitySettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ panic("unexpected GetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+	db, err := sql.Open("sqlite", "file:"+t.Name()+"?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-auth-identity-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+ settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ },
+ }, cfg)
+
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
+ return svc, repo, client
+}
+
+func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t)
+ ctx := context.Background()
+
+ token, user, err := svc.Register(ctx, "user@example.com", "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, user)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "email", storedUser.SignupSource)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("user@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.NotNil(t, identity.VerifiedAt)
+}
+
+func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
+ svc, repo, client := newAuthServiceWithEnt(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "login@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 1,
+ Concurrency: 1,
+ }
+ require.NoError(t, user.SetPassword("password"))
+ require.NoError(t, repo.Create(ctx, user))
+
+ old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
+ _, err := client.User.UpdateOneID(user.ID).
+ SetLastLoginAt(old).
+ SetLastActiveAt(old).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+ require.True(t, storedUser.LastLoginAt.After(old))
+ require.True(t, storedUser.LastActiveAt.After(old))
+}
diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go
deleted file mode 100644
index 0472e06c..00000000
--- a/backend/internal/service/auth_service_pending_oauth_test.go
+++ /dev/null
@@ -1,146 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/require"
-)
-
-func newAuthServiceForPendingOAuthTest() *AuthService {
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret-pending-oauth",
- ExpireHour: 1,
- },
- }
- return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
-}
-
-// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
-func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
- require.NotEmpty(t, token)
-
- email, username, err := svc.VerifyPendingOAuthToken(token)
- require.NoError(t, err)
- require.Equal(t, "user@example.com", email)
- require.Equal(t, "alice", username)
-}
-
-// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- // 签发一个普通 access token(JWTClaims,无 Purpose 字段)
- accessToken, err := svc.GenerateToken(&User{
- ID: 1,
- Email: "user@example.com",
- Role: RoleUser,
- })
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(accessToken)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "some_other_purpose",
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "", // 旧 token 无此字段,反序列化后为零值
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- past := time.Now().Add(-1 * time.Hour)
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(past),
- IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
- other := NewAuthService(nil, nil, nil, nil, &config.Config{
- JWT: config.JWTConfig{Secret: "other-secret"},
- }, nil, nil, nil, nil, nil, nil)
-
- token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
-
- svc := newAuthServiceForPendingOAuthTest()
- _, _, err = svc.VerifyPendingOAuthToken(token)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
- giant := make([]byte, maxTokenLength+1)
- for i := range giant {
- giant[i] = 'a'
- }
- _, _, err := svc.VerifyPendingOAuthToken(string(giant))
- require.ErrorIs(t, err, ErrInvalidToken)
-}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index cb452efb..1dddf77e 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
+// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
+const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
+
// Setting keys
const (
// 注册设置
@@ -153,6 +156,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
+ // 第三方认证来源默认授予配置
+ SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
+ SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
+ SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
+ SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
+ SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
+ SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
+ SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
+ SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
+
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 6c09e354..09e60220 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -13,14 +13,30 @@ import (
"sync"
"sync/atomic"
"time"
+
+ "golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
+ openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
+)
+
+const (
+ openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
+ openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
)
+type cachedOpenAIAdvancedSchedulerSetting struct {
+ enabled bool
+ expiresAt int64
+}
+
+var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
+var openAIAdvancedSchedulerSettingSF singleflight.Group
+
type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
-func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
+func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
+ if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
+ return nil
+ }
+ return s.rateLimitService.settingService.settingRepo
+}
+
+func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled
+ }
+ }
+
+ result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled, nil
+ }
+ }
+
+ enabled := false
+ if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
+ dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
+ defer cancel()
+
+ value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
+ if err == nil {
+ enabled = strings.EqualFold(strings.TrimSpace(value), "true")
+ }
+ }
+
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: enabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ return enabled, nil
+ })
+
+ enabled, _ := result.(bool)
+ return enabled
+}
+
+func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
+ if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
+ return nil
+ }
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
+func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
+	openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{})
+	openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
+}
+
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index 088815ed..a54f2614 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "errors"
"fmt"
"math"
"sync"
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
+type schedulerTestOpenAIAccountRepo struct {
+ AccountRepository
+ accounts []Account
+}
+
+func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, errors.New("account not found")
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ return r.ListSchedulableByPlatform(ctx, platform)
+}
+
+type schedulerTestConcurrencyCache struct {
+ ConcurrencyCache
+ loadBatchErr error
+ loadMap map[int64]*AccountLoadInfo
+ acquireResults map[int64]bool
+ waitCounts map[int64]int
+ skipDefaultLoad bool
+}
+
+func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ if c.acquireResults != nil {
+ if result, ok := c.acquireResults[accountID]; ok {
+ return result, nil
+ }
+ }
+ return true, nil
+}
+
+func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
+ return nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
+ if c.loadBatchErr != nil {
+ return nil, c.loadBatchErr
+ }
+ out := make(map[int64]*AccountLoadInfo, len(accounts))
+ if c.skipDefaultLoad && c.loadMap != nil {
+ for _, acc := range accounts {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ }
+ }
+ return out, nil
+ }
+ for _, acc := range accounts {
+ if c.loadMap != nil {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ continue
+ }
+ }
+ out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
+ }
+ return out, nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
+ if c.waitCounts != nil {
+ if count, ok := c.waitCounts[accountID]; ok {
+ return count, nil
+ }
+ }
+ return 0, nil
+}
+
+type schedulerTestGatewayCache struct {
+ sessionBindings map[string]int64
+ deletedSessions map[string]int
+}
+
+func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
+ if id, ok := c.sessionBindings[sessionHash]; ok {
+ return id, nil
+ }
+ return 0, errors.New("not found")
+}
+
+func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
+ if c.sessionBindings == nil {
+ c.sessionBindings = make(map[string]int64)
+ }
+ c.sessionBindings[sessionHash] = accountID
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if c.sessionBindings == nil {
+ return nil
+ }
+ if c.deletedSessions == nil {
+ c.deletedSessions = make(map[string]int)
+ }
+ c.deletedSessions[sessionHash]++
+ delete(c.sessionBindings, sessionHash)
+ return nil
+}
+
+func newSchedulerTestOpenAIWSV2Config() *config.Config {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ return cfg
+}
+
+type openAIAdvancedSchedulerSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &Setting{Key: key, Value: value}, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if s == nil || s.values == nil {
+ return "", ErrSettingNotFound
+ }
+ value, ok := s.values[key]
+ if !ok {
+ return "", ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected call to Set")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ panic("unexpected call to GetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected call to SetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected call to GetAll")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected call to Delete")
+}
+
+func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+ repo := &openAIAdvancedSchedulerSettingRepoStub{
+ values: map[string]string{},
+ }
+ if enabled != "" {
+ repo.values[openAIAdvancedSchedulerSettingKey] = enabled
+ }
+ return &RateLimitService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+}
+
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10106)
+ accounts := []Account{
+ {
+ ID: 36001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ },
+ {
+ ID: 36002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cache := &schedulerTestGatewayCache{}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
+ require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_disabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36002), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.False(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10107)
+ accounts := []Account{
+ {
+ ID: 37001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ {
+ ID: 37002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
+ require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_enabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(37001), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
+ require.True(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ svc := &OpenAIGatewayService{}
+ ttft := 120
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.RecordOpenAIAccountSwitch()
+
+ snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
+ cache: cache,
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ }
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
- cache := &stubGatewayCache{}
+ cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
- cfg: newOpenAIWSV2TestConfig(),
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: newSchedulerTestOpenAIWSV2Config(),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
- require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,
diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
index c5de8203..ddafc6eb 100644
--- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
+++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index 59764b29..34462a3a 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
+ SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
+ SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
cfg := s.parsePaymentConfig(vals)
+ if s.entClient != nil {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.EnabledEQ(true)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list enabled provider instances: %w", err)
+ }
+ cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances))
+ } else {
+ cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil)
+ }
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil
@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
+ types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
- cfg.EnabledTypes = append(cfg.EnabledTypes, t)
+ types = append(types, t)
}
}
+ cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
+ if s.entClient == nil {
+ return ""
+ }
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
+
+func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
+ available := make(map[string]bool, 4)
+ for _, inst := range instances {
+ switch inst.ProviderKey {
+ case payment.TypeAlipay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
+ available[VisibleMethodSourceOfficialAlipay] = true
+ }
+ case payment.TypeWxpay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
+ available[VisibleMethodSourceOfficialWechat] = true
+ }
+ case payment.TypeEasyPay:
+ for _, supportedType := range splitTypes(inst.SupportedTypes) {
+ switch NormalizeVisibleMethod(supportedType) {
+ case payment.TypeAlipay:
+ available[VisibleMethodSourceEasyPayAlipay] = true
+ case payment.TypeWxpay:
+ available[VisibleMethodSourceEasyPayWechat] = true
+ }
+ }
+ }
+ }
+ return available
+}
+
+func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
+ shouldExpose := map[string]bool{
+ payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
+ payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
+ }
+
+ seen := make(map[string]struct{}, len(base)+2)
+ out := make([]string, 0, len(base)+2)
+ appendType := func(paymentType string) {
+ paymentType = NormalizeVisibleMethod(paymentType)
+ if paymentType == "" {
+ return
+ }
+ if _, ok := seen[paymentType]; ok {
+ return
+ }
+ seen[paymentType] = struct{}{}
+ out = append(out, paymentType)
+ }
+
+ for _, paymentType := range base {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ switch visibleMethod {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ default:
+ appendType(visibleMethod)
+ }
+ }
+
+ for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ }
+ return out
+}
+
+func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
+ enabledKey := visibleMethodEnabledSettingKey(method)
+ sourceKey := visibleMethodSourceSettingKey(method)
+ if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
+ return false
+ }
+ source := NormalizeVisibleMethodSource(method, vals[sourceKey])
+ return source != "" && available[source]
+}
diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go
index 027bb796..10919058 100644
--- a/backend/internal/service/payment_config_service_test.go
+++ b/backend/internal/service/payment_config_service_test.go
@@ -1,9 +1,17 @@
package service
import (
+ "context"
+ "database/sql"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
+ t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
+ t.Parallel()
+ vals := map[string]string{
+ SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
+ }
+ cfg := svc.parsePaymentConfig(vals)
+ if len(cfg.EnabledTypes) != 2 {
+ t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
+ }
+ if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
+ t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
+ }
+ })
+
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
+
+func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"alipay", "wxpay", "stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceOfficialAlipay: true,
+ VisibleMethodSourceOfficialWechat: false,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"alipay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceEasyPayAlipay: true,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"stripe", "alipay"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
+ t.Parallel()
+
+ instances := []*dbent.PaymentProviderInstance{
+ {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
+ {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
+ {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
+ }
+
+ got := buildVisibleMethodSourceAvailability(instances)
+ if !got[VisibleMethodSourceOfficialAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
+ }
+ if !got[VisibleMethodSourceEasyPayAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
+ }
+ if !got[VisibleMethodSourceOfficialWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
+ }
+ if !got[VisibleMethodSourceEasyPayWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
+ }
+}
+
+func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: "easypay",
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: "wxpay",
+ },
+ },
+ }
+
+ cfg, err := svc.GetPaymentConfig(ctx)
+ if err != nil {
+ t.Fatalf("GetPaymentConfig returned error: %v", err)
+ }
+
+ want := []string{payment.TypeAlipay, payment.TypeStripe}
+ if len(cfg.EnabledTypes) != len(want) {
+ t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
+ }
+ for i := range want {
+ if cfg.EnabledTypes[i] != want[i] {
+ t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
+ }
+ }
+}
+
+func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared")
+ if err != nil {
+ t.Fatalf("open sqlite: %v", err)
+ }
+ t.Cleanup(func() { _ = db.Close() })
+
+ if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
+ t.Fatalf("enable foreign keys: %v", err)
+ }
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+type paymentConfigSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
+ return nil, nil
+}
+func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ return s.values[key], nil
+}
+func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
+func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ out[key] = s.values[key]
+ }
+ return out, nil
+}
+func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ return s.values, nil
+}
+func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
new file mode 100644
index 00000000..894a8198
--- /dev/null
+++ b/backend/internal/service/payment_resume_service.go
@@ -0,0 +1,248 @@
+package service
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const (
+ PaymentSourceHostedRedirect = "hosted_redirect"
+ PaymentSourceWechatInAppResume = "wechat_in_app_resume"
+
+ paymentResumeFallbackSigningKey = "sub2api-payment-resume"
+
+ SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
+ SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
+ SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
+ SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
+
+ VisibleMethodSourceOfficialAlipay = "official_alipay"
+ VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
+ VisibleMethodSourceOfficialWechat = "official_wxpay"
+ VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
+)
+
+type ResumeTokenClaims struct {
+ OrderID int64 `json:"oid"`
+ UserID int64 `json:"uid,omitempty"`
+ ProviderInstanceID string `json:"pi,omitempty"`
+ ProviderKey string `json:"pk,omitempty"`
+ PaymentType string `json:"pt,omitempty"`
+ CanonicalReturnURL string `json:"ru,omitempty"`
+ IssuedAt int64 `json:"iat"`
+}
+
+type PaymentResumeService struct {
+ signingKey []byte
+}
+
+type visibleMethodLoadBalancer struct {
+ inner payment.LoadBalancer
+ configService *PaymentConfigService
+}
+
+func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
+ return &PaymentResumeService{signingKey: signingKey}
+}
+
+func NormalizeVisibleMethod(method string) string {
+ return payment.GetBasePaymentType(strings.TrimSpace(method))
+}
+
+func NormalizeVisibleMethods(methods []string) []string {
+ if len(methods) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(methods))
+ out := make([]string, 0, len(methods))
+ for _, method := range methods {
+ normalized := NormalizeVisibleMethod(method)
+ if normalized == "" {
+ continue
+ }
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ out = append(out, normalized)
+ }
+ return out
+}
+
+func NormalizePaymentSource(source string) string {
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case "", PaymentSourceHostedRedirect:
+ return PaymentSourceHostedRedirect
+ case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
+ return PaymentSourceWechatInAppResume
+ default:
+ return strings.TrimSpace(strings.ToLower(source))
+ }
+}
+
+func NormalizeVisibleMethodSource(method, source string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
+ return VisibleMethodSourceOfficialAlipay
+ case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayAlipay
+ }
+ case payment.TypeWxpay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
+ return VisibleMethodSourceOfficialWechat
+ case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayWechat
+ }
+ }
+ return ""
+}
+
+func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
+ switch NormalizeVisibleMethodSource(method, source) {
+ case VisibleMethodSourceOfficialAlipay:
+ return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceEasyPayAlipay:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceOfficialWechat:
+ return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ case VisibleMethodSourceEasyPayWechat:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ default:
+ return "", false
+ }
+}
+
+func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
+ if inner == nil || configService == nil || configService.settingRepo == nil {
+ return inner
+ }
+ return &visibleMethodLoadBalancer{inner: inner, configService: configService}
+}
+
+func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
+ return lb.inner.GetInstanceConfig(ctx, instanceID)
+}
+
+func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
+ return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
+ }
+
+ enabledKey := visibleMethodEnabledSettingKey(visibleMethod)
+ sourceKey := visibleMethodSourceSettingKey(visibleMethod)
+ vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey})
+ if err != nil {
+ return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err)
+ }
+ if vals[enabledKey] != "true" {
+ return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod)
+ }
+
+ targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey])
+ if !ok {
+ return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod)
+ }
+ return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount)
+}
+
+func visibleMethodEnabledSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipayEnabled
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpayEnabled
+ default:
+ return ""
+ }
+}
+
+func visibleMethodSourceSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipaySource
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpaySource
+ default:
+ return ""
+ }
+}
+
+func CanonicalizeReturnURL(raw string) (string, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return "", nil
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil || !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
+ }
+ if parsed.Scheme != "http" && parsed.Scheme != "https" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
+ }
+ parsed.Fragment = ""
+ if parsed.Path == "" {
+ parsed.Path = "/"
+ }
+ return parsed.String(), nil
+}
+
+func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
+ if claims.OrderID <= 0 {
+ return "", fmt.Errorf("resume token requires order id")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ return "", fmt.Errorf("marshal resume claims: %w", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ return encodedPayload + "." + s.sign(encodedPayload), nil
+}
+
+func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ parts := strings.Split(token, ".")
+ if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
+ }
+ if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
+ }
+ payload, err := base64.RawURLEncoding.DecodeString(parts[0])
+ if err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
+ }
+ var claims ResumeTokenClaims
+ if err := json.Unmarshal(payload, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ }
+ if claims.OrderID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) sign(payload string) string {
+ key := s.signingKey
+ if len(key) == 0 {
+ key = []byte(paymentResumeFallbackSigningKey)
+ }
+ mac := hmac.New(sha256.New, key)
+ _, _ = mac.Write([]byte(payload))
+ return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+}
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
new file mode 100644
index 00000000..e56b4a88
--- /dev/null
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -0,0 +1,240 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+func TestNormalizeVisibleMethods(t *testing.T) {
+ t.Parallel()
+
+ got := NormalizeVisibleMethods([]string{
+ "alipay_direct",
+ "alipay",
+ " wxpay_direct ",
+ "wxpay",
+ "stripe",
+ })
+
+ want := []string{"alipay", "wxpay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestNormalizePaymentSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ expect string
+ }{
+ {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
+ {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
+ {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizePaymentSource(tt.input); got != tt.expect {
+ t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
+ }
+ })
+ }
+}
+
+func TestCanonicalizeReturnURL(t *testing.T) {
+ t.Parallel()
+
+ got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
+ if err != nil {
+ t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
+ }
+ if got != "https://example.com/pay/result?b=2" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject relative URLs")
+ }
+}
+
+func TestPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ ProviderInstanceID: "19",
+ ProviderKey: "easypay",
+ PaymentType: "wxpay",
+ CanonicalReturnURL: "https://example.com/payment/result",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseToken(token)
+ if err != nil {
+ t.Fatalf("ParseToken returned error: %v", err)
+ }
+ if claims.OrderID != 42 || claims.UserID != 7 {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
+ t.Fatalf("claims provider snapshot mismatch: %+v", claims)
+ }
+ if claims.CanonicalReturnURL != "https://example.com/payment/result" {
+ t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
+ }
+}
+
+func TestNormalizeVisibleMethodSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ input string
+ want string
+ }{
+ {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
+ {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
+ {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
+ {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
+ {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
+ t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodProviderKeyForSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ source string
+ want string
+ ok bool
+ }{
+ {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
+ {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
+ {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
+ {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
+ {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
+ if got != tt.want || ok != tt.ok {
+ t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
+ t.Parallel()
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ settingRepo: &paymentSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != payment.TypeAlipay {
+ t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
+ }
+}
+
+func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
+ t.Parallel()
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ settingRepo: &paymentSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
+ t.Fatal("SelectInstance should reject disabled visible method")
+ }
+}
+
+type paymentSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil }
+func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ return s.values[key], nil
+}
+func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil }
+func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ out[key] = s.values[key]
+ }
+ return out, nil
+}
+func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil }
+func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ return s.values, nil
+}
+func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil }
+
+type captureLoadBalancer struct {
+ lastProviderKey string
+ lastPaymentType string
+}
+
+func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
+ return map[string]string{}, nil
+}
+
+func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
+ c.lastProviderKey = providerKey
+ c.lastPaymentType = paymentType
+ return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
+}
diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go
index 6fc23f97..e897741a 100644
--- a/backend/internal/service/payment_service.go
+++ b/backend/internal/service/payment_service.go
@@ -65,15 +65,17 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
- UserID int64
- Amount float64
- PaymentType string
- ClientIP string
- IsMobile bool
- SrcHost string
- SrcURL string
- OrderType string
- PlanID int64
+ UserID int64
+ Amount float64
+ PaymentType string
+ ClientIP string
+ IsMobile bool
+ SrcHost string
+ SrcURL string
+ ReturnURL string
+ PaymentSource string
+ OrderType string
+ PlanID int64
}
type CreateOrderResponse struct {
@@ -88,6 +90,7 @@ type CreateOrderResponse struct {
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
+ ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
@@ -165,10 +168,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
+ resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
- return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+ svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+ svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
+ return svc
}
// --- Provider Registry ---
@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
return &s
}
+func (s *PaymentService) paymentResume() *PaymentResumeService {
+ if s.resumeService != nil {
+ return s.resumeService
+ }
+ return NewPaymentResumeService(psResumeSigningKey(s.configService))
+}
+
+func psResumeSigningKey(configService *PaymentConfigService) []byte {
+ if configService == nil {
+ return nil
+ }
+ return configService.encryptionKey
+}
+
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 7f4a2eb1..de555478 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/url"
+ "os"
"sort"
"strconv"
"strings"
@@ -114,6 +115,66 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
}
+type ProviderDefaultGrantSettings struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+ GrantOnSignup bool
+ GrantOnFirstBind bool
+}
+
+type AuthSourceDefaultSettings struct {
+ Email ProviderDefaultGrantSettings
+ LinuxDo ProviderDefaultGrantSettings
+ OIDC ProviderDefaultGrantSettings
+ WeChat ProviderDefaultGrantSettings
+ ForceEmailOnThirdPartySignup bool
+}
+
+type authSourceDefaultKeySet struct {
+ balance string
+ concurrency string
+ subscriptions string
+ grantOnSignup string
+ grantOnFirstBind string
+}
+
+var (
+ emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultEmailBalance,
+ concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ }
+ linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
+ concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ }
+ oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultOIDCBalance,
+ concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ }
+ weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultWeChatBalance,
+ concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ }
+)
+
+const (
+ defaultAuthSourceBalance = 0
+ defaultAuthSourceConcurrency = 5
+)
+
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
+ weChatEnabled := isWeChatOAuthConfigured()
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
+ WeChatOAuthEnabled: weChatEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
+func isWeChatOAuthConfigured() bool {
+ openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
+ strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
+ mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
+ strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
+ return openConfigured || mpConfigured
+}
+
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value)
}
+func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
+ keys := []string{
+ SettingKeyAuthSourceDefaultEmailBalance,
+ SettingKeyAuthSourceDefaultEmailConcurrency,
+ SettingKeyAuthSourceDefaultEmailSubscriptions,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultLinuxDoBalance,
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultOIDCBalance,
+ SettingKeyAuthSourceDefaultOIDCConcurrency,
+ SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultWeChatBalance,
+ SettingKeyAuthSourceDefaultWeChatConcurrency,
+ SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ SettingKeyForceEmailOnThirdPartySignup,
+ }
+
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return nil, fmt.Errorf("get auth source default settings: %w", err)
+ }
+
+ return &AuthSourceDefaultSettings{
+ Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
+ LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
+ OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
+ WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
+ }, nil
+}
+
+func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
+ if settings == nil {
+ return nil
+ }
+
+ for _, subscriptions := range [][]DefaultSubscriptionSetting{
+ settings.Email.Subscriptions,
+ settings.LinuxDo.Subscriptions,
+ settings.OIDC.Subscriptions,
+ settings.WeChat.Subscriptions,
+ } {
+ if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
+ return err
+ }
+ }
+
+ updates := make(map[string]string, 21)
+ writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
+ writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
+ writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
+ writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
+ updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
+
+ if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
+ return fmt.Errorf("update auth source default settings: %w", err)
+ }
+ return nil
+}
+
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyEmailVerifyEnabled: "false",
- SettingKeyRegistrationEmailSuffixWhitelist: "[]",
- SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
- SettingKeySiteName: "Sub2API",
- SettingKeySiteLogo: "",
- SettingKeyPurchaseSubscriptionEnabled: "false",
- SettingKeyPurchaseSubscriptionURL: "",
- SettingKeyTableDefaultPageSize: "20",
- SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- SettingKeyCustomMenuItems: "[]",
- SettingKeyCustomEndpoints: "[]",
- SettingKeyOIDCConnectEnabled: "false",
- SettingKeyOIDCConnectProviderName: "OIDC",
- SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
- SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
- SettingKeyDefaultSubscriptions: "[]",
- SettingKeySMTPPort: "587",
- SettingKeySMTPUseTLS: "false",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
+ SettingKeySiteName: "Sub2API",
+ SettingKeySiteLogo: "",
+ SettingKeyPurchaseSubscriptionEnabled: "false",
+ SettingKeyPurchaseSubscriptionURL: "",
+ SettingKeyTableDefaultPageSize: "20",
+ SettingKeyTablePageSizeOptions: "[10,20,50,100]",
+ SettingKeyCustomMenuItems: "[]",
+ SettingKeyCustomEndpoints: "[]",
+ SettingKeyOIDCConnectEnabled: "false",
+ SettingKeyOIDCConnectProviderName: "OIDC",
+ SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
+ SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyDefaultSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailBalance: "0",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultOIDCBalance: "0",
+ SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
+ SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultWeChatBalance: "0",
+ SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
+ SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true",
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
+ SettingKeyForceEmailOnThirdPartySignup: "false",
+ SettingKeySMTPPort: "587",
+ SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
+ result.OIDCConnectUsePKCE = true
+ result.OIDCConnectValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
+// parseProviderDefaultGrantSettings builds a ProviderDefaultGrantSettings from
+// the raw settings map using the given key set, falling back to the package
+// defaults (defaultAuthSourceBalance / defaultAuthSourceConcurrency, empty
+// subscriptions, grant-on-signup enabled) for any key that is missing or
+// fails to parse.
+func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
+	result := ProviderDefaultGrantSettings{
+		Balance:          defaultAuthSourceBalance,
+		Concurrency:      defaultAuthSourceConcurrency,
+		Subscriptions:    []DefaultSubscriptionSetting{},
+		GrantOnSignup:    true,
+		GrantOnFirstBind: false,
+	}
+
+	if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
+		result.Balance = v
+	}
+	if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
+		result.Concurrency = v
+	}
+	if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
+		result.Subscriptions = items
+	}
+	// Trim booleans for consistency with the numeric fields above so a stored
+	// value like " true " is not silently treated as false.
+	if raw, ok := settings[keys.grantOnSignup]; ok {
+		result.GrantOnSignup = strings.TrimSpace(raw) == "true"
+	}
+	if raw, ok := settings[keys.grantOnFirstBind]; ok {
+		result.GrantOnFirstBind = strings.TrimSpace(raw) == "true"
+	}
+
+	return result
+}
+
+// writeProviderDefaultGrantUpdates serializes the given provider grant
+// settings into the updates map under the provided key set. A nil
+// subscriptions slice is persisted as the JSON empty array "[]".
+func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
+	updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
+	updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
+
+	subs := settings.Subscriptions
+	if subs == nil {
+		subs = []DefaultSubscriptionSetting{}
+	}
+	serialized := "[]"
+	if raw, err := json.Marshal(subs); err == nil {
+		serialized = string(raw)
+	}
+	updates[keys.subscriptions] = serialized
+
+	updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
+	updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
+}
+
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
+ effective.UsePKCE = true
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
+ effective.UsePKCE = true
+ effective.ValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go
new file mode 100644
index 00000000..097bf604
--- /dev/null
+++ b/backend/internal/service/setting_service_auth_source_defaults_test.go
@@ -0,0 +1,136 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// authSourceDefaultsRepoStub is an in-memory SettingRepository test double.
+// Only GetMultiple/SetMultiple are implemented; every other method panics so
+// an unexpected call fails the test loudly.
+type authSourceDefaultsRepoStub struct {
+	values  map[string]string // backing store read by GetMultiple
+	updates map[string]string // snapshot of the last SetMultiple call
+}
+
+func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+	panic("unexpected Get call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+	panic("unexpected GetValue call")
+}
+
+func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
+	panic("unexpected Set call")
+}
+
+// GetMultiple returns only the requested keys that exist in values; missing
+// keys are simply omitted, mirroring the real repository contract.
+func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+	out := make(map[string]string, len(keys))
+	for _, key := range keys {
+		if value, ok := s.values[key]; ok {
+			out[key] = value
+		}
+	}
+	return out, nil
+}
+
+// SetMultiple records the write in updates (for assertions) and also applies
+// it to values so later reads observe the new state.
+func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+	s.updates = make(map[string]string, len(settings))
+	for key, value := range settings {
+		s.updates[key] = value
+		if s.values == nil {
+			s.values = map[string]string{}
+		}
+		s.values[key] = value
+	}
+	return nil
+}
+
+func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+	panic("unexpected GetAll call")
+}
+
+func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
+	panic("unexpected Delete call")
+}
+
+// Verifies that explicitly stored keys are parsed (Email provider) while
+// providers with no stored keys fall back to defaults (LinuxDo/OIDC/WeChat:
+// balance 0, concurrency 5, empty subscriptions, grant-on-signup true).
+func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
+	repo := &authSourceDefaultsRepoStub{
+		values: map[string]string{
+			SettingKeyAuthSourceDefaultEmailBalance:            "12.5",
+			SettingKeyAuthSourceDefaultEmailConcurrency:        "7",
+			SettingKeyAuthSourceDefaultEmailSubscriptions:      `[{"group_id":11,"validity_days":30}]`,
+			SettingKeyAuthSourceDefaultEmailGrantOnSignup:      "false",
+			SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
+			SettingKeyForceEmailOnThirdPartySignup:             "true",
+		},
+	}
+	svc := NewSettingService(repo, &config.Config{})
+
+	got, err := svc.GetAuthSourceDefaultSettings(context.Background())
+	require.NoError(t, err)
+	require.Equal(t, 12.5, got.Email.Balance)
+	require.Equal(t, 7, got.Email.Concurrency)
+	require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
+	require.False(t, got.Email.GrantOnSignup)
+	require.False(t, got.Email.GrantOnFirstBind)
+	require.Equal(t, 0.0, got.LinuxDo.Balance)
+	require.Equal(t, 5, got.LinuxDo.Concurrency)
+	require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
+	require.True(t, got.LinuxDo.GrantOnSignup)
+	require.True(t, got.LinuxDo.GrantOnFirstBind)
+	require.Equal(t, 5, got.OIDC.Concurrency)
+	require.Equal(t, 5, got.WeChat.Concurrency)
+	require.True(t, got.ForceEmailOnThirdPartySignup)
+}
+
+// Verifies that UpdateAuthSourceDefaultSettings persists every per-provider
+// key in the expected serialized formats: balance with 8 decimal places,
+// integer concurrency, stringified booleans, and subscriptions as JSON.
+func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
+	repo := &authSourceDefaultsRepoStub{}
+	svc := NewSettingService(repo, &config.Config{})
+
+	err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
+		Email: ProviderDefaultGrantSettings{
+			Balance:          1.25,
+			Concurrency:      3,
+			Subscriptions:    []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
+			GrantOnSignup:    false,
+			GrantOnFirstBind: true,
+		},
+		LinuxDo: ProviderDefaultGrantSettings{
+			Balance:          2,
+			Concurrency:      4,
+			Subscriptions:    []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
+			GrantOnSignup:    true,
+			GrantOnFirstBind: false,
+		},
+		OIDC: ProviderDefaultGrantSettings{
+			Balance:          3,
+			Concurrency:      5,
+			Subscriptions:    []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
+			GrantOnSignup:    true,
+			GrantOnFirstBind: true,
+		},
+		WeChat: ProviderDefaultGrantSettings{
+			Balance:          4,
+			Concurrency:      6,
+			Subscriptions:    []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
+			GrantOnSignup:    false,
+			GrantOnFirstBind: false,
+		},
+		ForceEmailOnThirdPartySignup: true,
+	})
+	require.NoError(t, err)
+	require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
+	require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
+	require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
+	require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
+	require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
+
+	// Subscriptions round-trip through JSON; decode and compare structurally.
+	var got []DefaultSubscriptionSetting
+	require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
+	require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index ab2eb274..e991ebef 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -152,6 +152,7 @@ type PublicSettings struct {
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
+ WeChatOAuthEnabled bool
BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index 59f8aa6b..d8b5325c 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -7,19 +7,27 @@ import (
)
type User struct {
- ID int64
- Email string
- Username string
- Notes string
- PasswordHash string
- Role string
- Balance float64
- Concurrency int
- Status string
- AllowedGroups []int64
- TokenVersion int64 // Incremented on password change to invalidate existing tokens
- CreatedAt time.Time
- UpdatedAt time.Time
+ ID int64
+ Email string
+ Username string
+ Notes string
+ AvatarURL string
+ AvatarSource string
+ AvatarMIME string
+ AvatarByteSize int
+ AvatarSHA256 string
+ PasswordHash string
+ Role string
+ Balance float64
+ Concurrency int
+ Status string
+ AllowedGroups []int64
+ TokenVersion int64 // Incremented on password change to invalidate existing tokens
+ SignupSource string
+ LastLoginAt *time.Time
+ LastActiveAt *time.Time
+ CreatedAt time.Time
+ UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 3490e804..2f6d9427 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -2,9 +2,13 @@ package service
import (
"context"
+ "crypto/sha256"
"crypto/subtle"
+ "encoding/base64"
+ "encoding/hex"
"fmt"
"log/slog"
+ "net/url"
"strings"
"time"
@@ -17,10 +21,14 @@ var (
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
+ ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
+ ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
+ ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
)
const (
- maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxInlineAvatarBytes = 100 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
@@ -47,6 +55,9 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
+ GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
+ UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
@@ -71,11 +82,30 @@ type UserRepository interface {
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+// UserAvatar is a stored avatar record as returned by the user repository.
+type UserAvatar struct {
+	StorageProvider string // where the bytes live; code below uses "inline" and "remote_url"
+	StorageKey      string
+	URL             string // data URL (inline) or http(s) URL (remote)
+	ContentType     string
+	ByteSize        int
+	SHA256          string // hex-encoded digest of the decoded image bytes
+}
+
+// UpsertUserAvatarInput carries the normalized avatar data passed to the
+// repository's UpsertUserAvatar. Field meanings mirror UserAvatar.
+type UpsertUserAvatarInput struct {
+	StorageProvider string
+	StorageKey      string
+	URL             string
+	ContentType     string
+	ByteSize        int
+	SHA256          string
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
+ if req.AvatarURL != nil {
+ avatarValue := strings.TrimSpace(*req.AvatarURL)
+ switch {
+ case avatarValue == "":
+ if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
+ return nil, fmt.Errorf("delete avatar: %w", err)
+ }
+ applyUserAvatar(user, nil)
+ default:
+ avatarInput, err := normalizeUserAvatarInput(avatarValue)
+ if err != nil {
+ return nil, err
+ }
+ avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
+ if err != nil {
+ return nil, fmt.Errorf("upsert avatar: %w", err)
+ }
+ applyUserAvatar(user, avatar)
+ }
+ }
+
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil
}
+// applyUserAvatar copies avatar metadata onto the user. A nil avatar clears
+// every avatar field; a nil user is a no-op.
+func applyUserAvatar(user *User, avatar *UserAvatar) {
+	if user == nil {
+		return
+	}
+	// Substituting a zero-value record makes the nil case fall through to the
+	// same field assignments, clearing everything.
+	if avatar == nil {
+		avatar = &UserAvatar{}
+	}
+	user.AvatarURL = avatar.URL
+	user.AvatarSource = avatar.StorageProvider
+	user.AvatarMIME = avatar.ContentType
+	user.AvatarByteSize = avatar.ByteSize
+	user.AvatarSHA256 = avatar.SHA256
+}
+
+// normalizeUserAvatarInput validates a user-supplied avatar value and maps it
+// to an UpsertUserAvatarInput. "data:" URLs are delegated to
+// normalizeInlineUserAvatarInput; anything else must be an http(s) URL with a
+// non-empty host and is stored by reference as provider "remote_url".
+func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+	trimmed := strings.TrimSpace(raw)
+	if trimmed == "" {
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	}
+	if strings.HasPrefix(trimmed, "data:") {
+		return normalizeInlineUserAvatarInput(trimmed)
+	}
+
+	parsed, err := url.Parse(trimmed)
+	switch {
+	case err != nil || parsed == nil:
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	case !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https"):
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	case strings.TrimSpace(parsed.Host) == "":
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	}
+
+	return UpsertUserAvatarInput{
+		StorageProvider: "remote_url",
+		URL:             trimmed,
+	}, nil
+}
+
+// normalizeInlineUserAvatarInput validates a base64 "data:" URL and converts
+// it to an inline UpsertUserAvatarInput. It requires a ";base64" data URL
+// with an "image/*" media type, enforces maxInlineAvatarBytes on the decoded
+// payload, and records the SHA-256 of the decoded bytes for deduplication.
+func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+	body := strings.TrimPrefix(raw, "data:")
+	meta, encoded, ok := strings.Cut(body, ",")
+	if !ok {
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	}
+	meta = strings.TrimSpace(meta)
+	encoded = strings.TrimSpace(encoded)
+	if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	}
+
+	contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
+	if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
+		return UpsertUserAvatarInput{}, ErrAvatarNotImage
+	}
+
+	// Reject clearly oversized payloads before decoding so a multi-megabyte
+	// data URL is not fully base64-decoded just to be thrown away. DecodedLen
+	// is an upper bound (padding may shrink the real size by up to 2 bytes),
+	// so subtract 2 to never reject anything the decoded-size check below
+	// would have accepted.
+	if base64.StdEncoding.DecodedLen(len(encoded))-2 > maxInlineAvatarBytes {
+		return UpsertUserAvatarInput{}, ErrAvatarTooLarge
+	}
+
+	decoded, err := base64.StdEncoding.DecodeString(encoded)
+	if err != nil {
+		return UpsertUserAvatarInput{}, ErrAvatarInvalid
+	}
+	if len(decoded) > maxInlineAvatarBytes {
+		return UpsertUserAvatarInput{}, ErrAvatarTooLarge
+	}
+
+	sum := sha256.Sum256(decoded)
+	return UpsertUserAvatarInput{
+		StorageProvider: "inline",
+		URL:             raw,
+		ContentType:     contentType,
+		ByteSize:        len(decoded),
+		SHA256:          hex.EncodeToString(sum[:]),
+	}, nil
+}
+
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
+// hydrateUserAvatar loads the user's stored avatar (if any) from the
+// repository and applies it to the user struct. It is a safe no-op for a nil
+// service/repository and for nil or unsaved (ID == 0) users.
+func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
+	if s == nil || s.userRepo == nil {
+		return nil
+	}
+	if user == nil || user.ID == 0 {
+		return nil
+	}
+
+	avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
+	if err != nil {
+		return err
+	}
+	applyUserAvatar(user, avatar)
+	return nil
+}
+
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index a998d5f4..7d63bb36 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -4,6 +4,9 @@ package service
import (
"context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
"errors"
"sync"
"sync/atomic"
@@ -19,14 +22,65 @@ import (
type mockUserRepo struct {
updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ getByIDUser *User
+ getByIDErr error
+ updateFn func(ctx context.Context, user *User) error
+ updateCalls int
+ upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarFn func(ctx context.Context, userID int64) error
+ deleteAvatarIDs []int64
+ getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
}
-func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
-func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
+func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
+func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
+ if m.getByIDErr != nil {
+ return nil, m.getByIDErr
+ }
+ if m.getByIDUser != nil {
+ cloned := *m.getByIDUser
+ return &cloned, nil
+ }
+ return &User{}, nil
+}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
-func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
+ m.updateCalls++
+ if m.updateFn != nil {
+ return m.updateFn(ctx, user)
+ }
+ return nil
+}
+func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+// GetUserAvatar delegates to getAvatarFn when set; otherwise reports "no
+// avatar stored" (nil, nil).
+func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+	if m.getAvatarFn != nil {
+		return m.getAvatarFn(ctx, userID)
+	}
+	return nil, nil
+}
+
+// UpsertUserAvatar records every input for assertions, then either delegates
+// to upsertAvatarFn or echoes the input back as a stored record.
+func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+	m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
+	if m.upsertAvatarFn != nil {
+		return m.upsertAvatarFn(ctx, userID, input)
+	}
+	return &UserAvatar{
+		StorageProvider: input.StorageProvider,
+		StorageKey:      input.StorageKey,
+		URL:             input.URL,
+		ContentType:     input.ContentType,
+		ByteSize:        input.ByteSize,
+		SHA256:          input.SHA256,
+	}, nil
+}
+
+// DeleteUserAvatar records the target user ID, then delegates to
+// deleteAvatarFn when set.
+func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+	m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
+	if m.deleteAvatarFn != nil {
+		return m.deleteAvatarFn(ctx, userID)
+	}
+	return nil
+}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
+
+// Happy path: an inline data URL within the size limit is upserted with
+// provider "inline" and the content type, byte size and SHA-256 of the
+// decoded bytes; the returned user reflects the stored avatar.
+func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
+	raw := []byte("small-avatar")
+	dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+	expectedSum := sha256.Sum256(raw)
+	repo := &mockUserRepo{
+		getByIDUser: &User{
+			ID:       7,
+			Email:    "avatar@example.com",
+			Username: "avatar-user",
+		},
+	}
+	svc := NewUserService(repo, nil, nil, nil)
+
+	updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
+		AvatarURL: &dataURL,
+	})
+	require.NoError(t, err)
+	require.Len(t, repo.upsertAvatarArgs, 1)
+	require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+	require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
+	require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
+	require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
+	require.Equal(t, dataURL, updated.AvatarURL)
+	require.Equal(t, "inline", updated.AvatarSource)
+	require.Equal(t, "image/png", updated.AvatarMIME)
+	require.Equal(t, len(raw), updated.AvatarByteSize)
+	require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
+}
+
+// An inline avatar one byte over maxInlineAvatarBytes is rejected with
+// ErrAvatarTooLarge, and no repository write of any kind happens.
+func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
+	raw := make([]byte, maxInlineAvatarBytes+1)
+	dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+	repo := &mockUserRepo{
+		getByIDUser: &User{
+			ID:       8,
+			Email:    "large-avatar@example.com",
+			Username: "too-large",
+		},
+	}
+	svc := NewUserService(repo, nil, nil, nil)
+
+	_, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
+		AvatarURL: &dataURL,
+	})
+	require.ErrorIs(t, err, ErrAvatarTooLarge)
+	require.Empty(t, repo.upsertAvatarArgs)
+	require.Empty(t, repo.deleteAvatarIDs)
+	require.Zero(t, repo.updateCalls)
+}
+
+// An https URL is stored by reference: provider "remote_url", no byte size.
+func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
+	remoteURL := "https://cdn.example.com/avatar.png"
+	repo := &mockUserRepo{
+		getByIDUser: &User{
+			ID:       9,
+			Email:    "remote-avatar@example.com",
+			Username: "remote-avatar",
+		},
+	}
+	svc := NewUserService(repo, nil, nil, nil)
+
+	updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
+		AvatarURL: &remoteURL,
+	})
+	require.NoError(t, err)
+	require.Len(t, repo.upsertAvatarArgs, 1)
+	require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
+	require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
+	require.Equal(t, remoteURL, updated.AvatarURL)
+	require.Equal(t, "remote_url", updated.AvatarSource)
+	require.Zero(t, updated.AvatarByteSize)
+}
+
+// An explicit empty string deletes the stored avatar and clears the user's
+// avatar fields; no upsert is attempted.
+func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
+	empty := ""
+	repo := &mockUserRepo{
+		getByIDUser: &User{
+			ID:           10,
+			Email:        "delete-avatar@example.com",
+			Username:     "delete-avatar",
+			AvatarURL:    "https://cdn.example.com/old.png",
+			AvatarSource: "remote_url",
+		},
+	}
+	svc := NewUserService(repo, nil, nil, nil)
+
+	updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
+		AvatarURL: &empty,
+	})
+	require.NoError(t, err)
+	require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
+	require.Empty(t, repo.upsertAvatarArgs)
+	require.Empty(t, updated.AvatarURL)
+	require.Empty(t, updated.AvatarSource)
+}
+
+// GetProfile should hydrate avatar fields from the repository's avatar store.
+func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
+	repo := &mockUserRepo{
+		getByIDUser: &User{
+			ID:       12,
+			Email:    "profile-avatar@example.com",
+			Username: "profile-avatar",
+		},
+		getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
+			return &UserAvatar{
+				StorageProvider: "remote_url",
+				URL:             "https://cdn.example.com/profile.png",
+			}, nil
+		},
+	}
+	svc := NewUserService(repo, nil, nil, nil)
+
+	user, err := svc.GetProfile(context.Background(), 12)
+	require.NoError(t, err)
+	require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
+	require.Equal(t, "remote_url", user.AvatarSource)
+}
diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql
new file mode 100644
index 00000000..117e3ca3
--- /dev/null
+++ b/backend/migrations/108_auth_identity_foundation_core.sql
@@ -0,0 +1,141 @@
+-- Migration 108: auth identity foundation.
+-- Adds signup/activity tracking columns to users and creates the tables
+-- backing multi-provider auth identities, pending auth flows, adoption
+-- decisions and migration reporting. All statements are idempotent
+-- (IF NOT EXISTS / guarded DO block) so re-running is safe.
+
+-- Track which provider created the account and when it was last seen.
+ALTER TABLE users
+ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
+ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
+ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
+
+UPDATE users
+SET signup_source = 'email'
+WHERE signup_source IS NULL OR signup_source = '';
+
+-- ADD CONSTRAINT has no IF NOT EXISTS, so probe pg_constraint first.
+DO $$
+BEGIN
+    IF NOT EXISTS (
+        SELECT 1
+        FROM pg_constraint
+        WHERE conname = 'users_signup_source_check'
+    ) THEN
+        ALTER TABLE users
+        ADD CONSTRAINT users_signup_source_check
+        CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
+    END IF;
+END $$;
+
+-- One row per (provider, subject) identity, owned by a user.
+CREATE TABLE IF NOT EXISTS auth_identities (
+    id BIGSERIAL PRIMARY KEY,
+    user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    provider_type VARCHAR(20) NOT NULL,
+    provider_key TEXT NOT NULL,
+    provider_subject TEXT NOT NULL,
+    verified_at TIMESTAMPTZ NULL,
+    issuer TEXT NULL,
+    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    CONSTRAINT auth_identities_provider_type_check
+        CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
+    ON auth_identities (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
+    ON auth_identities (user_id);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
+    ON auth_identities (user_id, provider_type);
+
+-- Per-channel bindings (e.g. app-scoped subjects) attached to an identity.
+CREATE TABLE IF NOT EXISTS auth_identity_channels (
+    id BIGSERIAL PRIMARY KEY,
+    identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
+    provider_type VARCHAR(20) NOT NULL,
+    provider_key TEXT NOT NULL,
+    channel VARCHAR(20) NOT NULL,
+    channel_app_id TEXT NOT NULL,
+    channel_subject TEXT NOT NULL,
+    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    CONSTRAINT auth_identity_channels_provider_type_check
+        CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
+    ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
+    ON auth_identity_channels (identity_id);
+
+-- In-flight login/bind/adopt flows; rows are token-addressed and expire.
+CREATE TABLE IF NOT EXISTS pending_auth_sessions (
+    id BIGSERIAL PRIMARY KEY,
+    session_token VARCHAR(255) NOT NULL,
+    intent VARCHAR(40) NOT NULL,
+    provider_type VARCHAR(20) NOT NULL,
+    provider_key TEXT NOT NULL,
+    provider_subject TEXT NOT NULL,
+    target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+    redirect_to TEXT NOT NULL DEFAULT '',
+    resolved_email TEXT NOT NULL DEFAULT '',
+    registration_password_hash TEXT NOT NULL DEFAULT '',
+    upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
+    local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
+    browser_session_key TEXT NOT NULL DEFAULT '',
+    completion_code_hash TEXT NOT NULL DEFAULT '',
+    completion_code_expires_at TIMESTAMPTZ NULL,
+    email_verified_at TIMESTAMPTZ NULL,
+    password_verified_at TIMESTAMPTZ NULL,
+    totp_verified_at TIMESTAMPTZ NULL,
+    expires_at TIMESTAMPTZ NOT NULL,
+    consumed_at TIMESTAMPTZ NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    CONSTRAINT pending_auth_sessions_intent_check
+        CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
+    CONSTRAINT pending_auth_sessions_provider_type_check
+        CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
+    ON pending_auth_sessions (session_token);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
+    ON pending_auth_sessions (target_user_id);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
+    ON pending_auth_sessions (expires_at);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
+    ON pending_auth_sessions (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
+    ON pending_auth_sessions (completion_code_hash);
+
+-- At most one adoption decision per pending session (unique index below).
+CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
+    id BIGSERIAL PRIMARY KEY,
+    pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
+    identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
+    adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
+    adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
+    decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
+    ON identity_adoption_decisions (pending_auth_session_id);
+
+CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
+    ON identity_adoption_decisions (identity_id);
+
+-- Audit log for the backfill migrations; (type, key) is unique so the same
+-- report is never recorded twice.
+CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
+    id BIGSERIAL PRIMARY KEY,
+    report_type VARCHAR(40) NOT NULL,
+    report_key TEXT NOT NULL,
+    details JSONB NOT NULL DEFAULT '{}'::jsonb,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
+    ON auth_identity_migration_reports (report_type);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
+    ON auth_identity_migration_reports (report_type, report_key);
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
new file mode 100644
index 00000000..ddbbedbc
--- /dev/null
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -0,0 +1,125 @@
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'email',
+ 'email',
+ LOWER(BTRIM(u.email)),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'users.email',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(u.email, '')) <> ''
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'linuxdo',
+ 'linuxdo',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'wechat',
+ 'wechat',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+UPDATE users
+SET signup_source = 'linuxdo'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'wechat'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'oidc'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'oidc_synthetic_email_requires_manual_recovery',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities ai
+ WHERE ai.user_id = u.id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ )
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
new file mode 100644
index 00000000..fbaed62e
--- /dev/null
+++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
@@ -0,0 +1,60 @@
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
+ granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT user_provider_default_grants_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
+ CONSTRAINT user_provider_default_grants_reason_check
+ CHECK (grant_reason IN ('signup', 'first_bind'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
+ ON user_provider_default_grants (user_id, provider_type, grant_reason);
+
+CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
+ ON user_provider_default_grants (user_id);
+
+CREATE TABLE IF NOT EXISTS user_avatars (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL DEFAULT '',
+ content_type VARCHAR(100) NOT NULL DEFAULT '',
+ byte_size INT NOT NULL DEFAULT 0,
+ sha256 VARCHAR(64) NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
+ ON user_avatars (user_id);
+
+INSERT INTO settings (key, value)
+VALUES
+ ('auth_source_default_email_balance', '0'),
+ ('auth_source_default_email_concurrency', '5'),
+ ('auth_source_default_email_subscriptions', '[]'),
+ ('auth_source_default_email_grant_on_signup', 'true'),
+ ('auth_source_default_email_grant_on_first_bind', 'false'),
+ ('auth_source_default_linuxdo_balance', '0'),
+ ('auth_source_default_linuxdo_concurrency', '5'),
+ ('auth_source_default_linuxdo_subscriptions', '[]'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'true'),
+ ('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
+ ('auth_source_default_oidc_balance', '0'),
+ ('auth_source_default_oidc_concurrency', '5'),
+ ('auth_source_default_oidc_subscriptions', '[]'),
+ ('auth_source_default_oidc_grant_on_signup', 'true'),
+ ('auth_source_default_oidc_grant_on_first_bind', 'false'),
+ ('auth_source_default_wechat_balance', '0'),
+ ('auth_source_default_wechat_concurrency', '5'),
+ ('auth_source_default_wechat_subscriptions', '[]'),
+ ('auth_source_default_wechat_grant_on_signup', 'true'),
+ ('auth_source_default_wechat_grant_on_first_bind', 'false'),
+ ('force_email_on_third_party_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
+
diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
new file mode 100644
index 00000000..f222a8d4
--- /dev/null
+++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
@@ -0,0 +1,8 @@
+INSERT INTO settings (key, value)
+VALUES
+ ('payment_visible_method_alipay_source', ''),
+ ('payment_visible_method_wxpay_source', ''),
+ ('payment_visible_method_alipay_enabled', 'false'),
+ ('payment_visible_method_wxpay_enabled', 'false'),
+ ('openai_advanced_scheduler_enabled', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
new file mode 100644
index 00000000..574e1e36
--- /dev/null
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -0,0 +1,60 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const post = vi.fn()
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post
+ }
+}))
+
+describe('oauth adoption auth api', () => {
+ beforeEach(() => {
+ post.mockReset()
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('posts adoption decisions when exchanging pending oauth completion', async () => {
+ const { exchangePendingOAuthCompletion } = await import('@/api/auth')
+
+ await exchangePendingOAuthCompletion({
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts linuxdo invitation completion with adoption decisions', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts oidc invitation completion with adoption decisions', async () => {
+ const { completeOIDCOAuthRegistration } = await import('@/api/auth')
+
+ await completeOIDCOAuthRegistration('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
new file mode 100644
index 00000000..8756146e
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
@@ -0,0 +1,118 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ type UpdateSettingsRequest,
+} from '@/api/admin/settings'
+
+describe('admin settings auth source defaults helpers', () => {
+ it('builds auth source defaults state from flat settings fields', () => {
+ const state = buildAuthSourceDefaultsState({
+ auth_source_default_email_balance: 9.5,
+ auth_source_default_email_concurrency: 3,
+ auth_source_default_email_subscriptions: [
+ { group_id: 1, validity_days: 30 },
+ ],
+ auth_source_default_email_grant_on_signup: false,
+ auth_source_default_email_grant_on_first_bind: true,
+ auth_source_default_linuxdo_balance: 6,
+ auth_source_default_linuxdo_concurrency: 8,
+ auth_source_default_linuxdo_subscriptions: [
+ { group_id: 2, validity_days: 60 },
+ ],
+ auth_source_default_linuxdo_grant_on_signup: true,
+ auth_source_default_linuxdo_grant_on_first_bind: false,
+ })
+
+ expect(state.email).toEqual({
+ balance: 9.5,
+ concurrency: 3,
+ subscriptions: [{ group_id: 1, validity_days: 30 }],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ })
+ expect(state.linuxdo).toEqual({
+ balance: 6,
+ concurrency: 8,
+ subscriptions: [{ group_id: 2, validity_days: 60 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ })
+ expect(state.oidc).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ })
+ expect(state.wechat).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ })
+ })
+
+ it('appends auth source defaults back onto update payload', () => {
+ const payload: UpdateSettingsRequest = {
+ site_name: 'Sub2API',
+ }
+
+ appendAuthSourceDefaultsToUpdateRequest(payload, {
+ email: {
+ balance: 1.25,
+ concurrency: 2,
+ subscriptions: [{ group_id: 3, validity_days: 7 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ },
+ linuxdo: {
+ balance: 0,
+ concurrency: 6,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ },
+ oidc: {
+ balance: 4,
+ concurrency: 9,
+ subscriptions: [{ group_id: 9, validity_days: 90 }],
+ grant_on_signup: true,
+ grant_on_first_bind: true,
+ },
+ wechat: {
+ balance: 2,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ },
+ })
+
+ expect(payload).toMatchObject({
+ site_name: 'Sub2API',
+ auth_source_default_email_balance: 1.25,
+ auth_source_default_email_concurrency: 2,
+ auth_source_default_email_subscriptions: [{ group_id: 3, validity_days: 7 }],
+ auth_source_default_email_grant_on_signup: true,
+ auth_source_default_email_grant_on_first_bind: false,
+ auth_source_default_linuxdo_balance: 0,
+ auth_source_default_linuxdo_concurrency: 6,
+ auth_source_default_linuxdo_subscriptions: [],
+ auth_source_default_linuxdo_grant_on_signup: false,
+ auth_source_default_linuxdo_grant_on_first_bind: true,
+ auth_source_default_oidc_balance: 4,
+ auth_source_default_oidc_concurrency: 9,
+ auth_source_default_oidc_subscriptions: [{ group_id: 9, validity_days: 90 }],
+ auth_source_default_oidc_grant_on_signup: true,
+ auth_source_default_oidc_grant_on_first_bind: true,
+ auth_source_default_wechat_balance: 2,
+ auth_source_default_wechat_concurrency: 5,
+ auth_source_default_wechat_subscriptions: [],
+ auth_source_default_wechat_grant_on_signup: false,
+ auth_source_default_wechat_grant_on_first_bind: false,
+ })
+ })
+})
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 1e4a3053..8e182c1c 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -11,6 +11,81 @@ export interface DefaultSubscriptionSetting {
validity_days: number
}
+export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat'
+
+export interface AuthSourceDefaultsValue {
+ balance: number
+ concurrency: number
+ subscriptions: DefaultSubscriptionSetting[]
+ grant_on_signup: boolean
+ grant_on_first_bind: boolean
+}
+
+export type AuthSourceDefaultsState = Record<AuthSourceType, AuthSourceDefaultsValue>
+
+const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat']
+const AUTH_SOURCE_DEFAULT_BALANCE = 0
+const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5
+
+export function normalizeDefaultSubscriptionSettings(
+ subscriptions: DefaultSubscriptionSetting[] | null | undefined
+): DefaultSubscriptionSetting[] {
+ if (!Array.isArray(subscriptions)) return []
+
+ return subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item) => ({
+ group_id: Math.floor(item.group_id),
+ validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days)))
+ }))
+}
+
+export function buildAuthSourceDefaultsState(
+ settings: Partial<SystemSettings>
+): AuthSourceDefaultsState {
+ const raw = settings as Record<string, unknown>
+
+ return AUTH_SOURCE_TYPES.reduce((acc, source) => {
+ const subscriptions = raw[`auth_source_default_${source}_subscriptions`]
+ acc[source] = {
+ balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE),
+ concurrency: Math.max(
+ 1,
+ Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY)
+ ),
+ subscriptions: normalizeDefaultSubscriptionSettings(
+ Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : []
+ ),
+ grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false,
+ grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
+ }
+ return acc
+ }, {} as AuthSourceDefaultsState)
+}
+
+export function appendAuthSourceDefaultsToUpdateRequest(
+ payload: UpdateSettingsRequest,
+ authSourceDefaults: AuthSourceDefaultsState
+): UpdateSettingsRequest {
+ const target = payload as Record<string, unknown>
+
+ for (const source of AUTH_SOURCE_TYPES) {
+ const current = authSourceDefaults[source]
+ target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0
+ target[`auth_source_default_${source}_concurrency`] = Math.max(
+ 1,
+ Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY)
+ )
+ target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings(
+ current.subscriptions
+ )
+ target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup
+ target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind
+ }
+
+ return payload
+}
+
/**
* System settings interface
*/
@@ -29,6 +104,27 @@ export interface SystemSettings {
default_balance: number
default_concurrency: number
default_subscriptions: DefaultSubscriptionSetting[]
+ auth_source_default_email_balance?: number
+ auth_source_default_email_concurrency?: number
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_email_grant_on_signup?: boolean
+ auth_source_default_email_grant_on_first_bind?: boolean
+ auth_source_default_linuxdo_balance?: number
+ auth_source_default_linuxdo_concurrency?: number
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_linuxdo_grant_on_signup?: boolean
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean
+ auth_source_default_oidc_balance?: number
+ auth_source_default_oidc_concurrency?: number
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_oidc_grant_on_signup?: boolean
+ auth_source_default_oidc_grant_on_first_bind?: boolean
+ auth_source_default_wechat_balance?: number
+ auth_source_default_wechat_concurrency?: number
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_wechat_grant_on_signup?: boolean
+ auth_source_default_wechat_grant_on_first_bind?: boolean
+ force_email_on_third_party_signup?: boolean
// OEM settings
site_name: string
site_logo: string
@@ -137,6 +233,11 @@ export interface SystemSettings {
payment_cancel_rate_limit_window: number
payment_cancel_rate_limit_unit: string
payment_cancel_rate_limit_window_mode: string
+ payment_visible_method_alipay_source?: string
+ payment_visible_method_wxpay_source?: string
+ payment_visible_method_alipay_enabled?: boolean
+ payment_visible_method_wxpay_enabled?: boolean
+ openai_advanced_scheduler_enabled?: boolean
// Balance & quota notification
balance_low_notify_enabled: boolean
@@ -158,6 +259,27 @@ export interface UpdateSettingsRequest {
default_balance?: number
default_concurrency?: number
default_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_email_balance?: number
+ auth_source_default_email_concurrency?: number
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_email_grant_on_signup?: boolean
+ auth_source_default_email_grant_on_first_bind?: boolean
+ auth_source_default_linuxdo_balance?: number
+ auth_source_default_linuxdo_concurrency?: number
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_linuxdo_grant_on_signup?: boolean
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean
+ auth_source_default_oidc_balance?: number
+ auth_source_default_oidc_concurrency?: number
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_oidc_grant_on_signup?: boolean
+ auth_source_default_oidc_grant_on_first_bind?: boolean
+ auth_source_default_wechat_balance?: number
+ auth_source_default_wechat_concurrency?: number
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]
+ auth_source_default_wechat_grant_on_signup?: boolean
+ auth_source_default_wechat_grant_on_first_bind?: boolean
+ force_email_on_third_party_signup?: boolean
site_name?: string
site_logo?: string
site_subtitle?: string
@@ -245,6 +367,11 @@ export interface UpdateSettingsRequest {
payment_cancel_rate_limit_window?: number
payment_cancel_rate_limit_unit?: string
payment_cancel_rate_limit_window_mode?: string
+ payment_visible_method_alipay_source?: string
+ payment_visible_method_wxpay_source?: string
+ payment_visible_method_alipay_enabled?: boolean
+ payment_visible_method_wxpay_enabled?: boolean
+ openai_advanced_scheduler_enabled?: boolean
// Balance & quota notification
balance_low_notify_enabled?: boolean
balance_low_notify_threshold?: number
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index d7abcd6a..10b6ca58 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -198,6 +198,26 @@ export interface PendingOAuthExchangeResponse {
suggested_avatar_url?: string
}
+export interface OAuthAdoptionDecision {
+ adoptDisplayName?: boolean
+ adoptAvatar?: boolean
+}
+
+function serializeOAuthAdoptionDecision(
+ decision?: OAuthAdoptionDecision
+): Record<string, boolean> {
+ const payload: Record<string, boolean> = {}
+
+ if (typeof decision?.adoptDisplayName === 'boolean') {
+ payload.adopt_display_name = decision.adoptDisplayName
+ }
+ if (typeof decision?.adoptAvatar === 'boolean') {
+ payload.adopt_avatar = decision.adoptAvatar
+ }
+
+ return payload
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -353,7 +373,8 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
const { data } = await apiClient.post<{
access_token: string
@@ -361,7 +382,8 @@ export async function completeLinuxDoOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/linuxdo/complete-registration', {
- invitation_code: invitationCode
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
})
return data
}
@@ -372,7 +394,8 @@ export async function completeLinuxDoOAuthRegistration(
* @returns Token pair on success
*/
export async function completeOIDCOAuthRegistration(
- invitationCode: string
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
const { data } = await apiClient.post<{
access_token: string
@@ -380,13 +403,19 @@ export async function completeOIDCOAuthRegistration(
expires_in: number
token_type: string
}>('/auth/oauth/oidc/complete-registration', {
- invitation_code: invitationCode
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
})
return data
}
-export async function exchangePendingOAuthCompletion(): Promise<PendingOAuthExchangeResponse> {
- const { data } = await apiClient.post<PendingOAuthExchangeResponse>('/auth/oauth/pending/exchange', {})
+export async function exchangePendingOAuthCompletion(
+ decision?: OAuthAdoptionDecision
+): Promise<PendingOAuthExchangeResponse> {
+ const { data } = await apiClient.post<PendingOAuthExchangeResponse>(
+ '/auth/oauth/pending/exchange',
+ serializeOAuthAdoptionDecision(decision)
+ )
return data
}
diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue
new file mode 100644
index 00000000..94e20222
--- /dev/null
+++ b/frontend/src/components/auth/WechatOAuthSection.vue
@@ -0,0 +1,53 @@
+
+
+
+
+ W
+
+ {{ t('auth.oidc.signIn', { providerName }) }}
+
+
+
+
+
+ {{ t('auth.oauthOrContinue') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
new file mode 100644
index 00000000..810832a0
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
@@ -0,0 +1,74 @@
+import { mount } from '@vue/test-utils'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue'
+
+const routeState = vi.hoisted(() => ({
+ query: {} as Record<string, string>,
+}))
+
+const locationState = vi.hoisted(() => ({
+ current: { href: 'http://localhost/login' } as { href: string },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+}))
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string, params?: Record<string, unknown>) => {
+ if (key === 'auth.oidc.signIn') {
+ return `Continue with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oauthOrContinue') {
+ return 'or continue'
+ }
+ return key
+ },
+ }),
+}))
+
+describe('WechatOAuthSection', () => {
+ beforeEach(() => {
+ routeState.query = { redirect: '/billing?plan=pro' }
+ locationState.current = { href: 'http://localhost/login' }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0',
+ })
+ })
+
+ afterEach(() => {
+ vi.unstubAllGlobals()
+ })
+
+ it('starts the open WeChat OAuth flow with the current redirect target', async () => {
+ const wrapper = mount(WechatOAuthSection)
+
+ expect(wrapper.text()).toContain('WeChat')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('uses mp mode inside the WeChat browser', async () => {
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0 MicroMessenger',
+ })
+ const wrapper = mount(WechatOAuthSection)
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+})
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 974dee66..8f5a5666 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -141,7 +141,9 @@ const props = defineProps<{
orderType?: string
}>()
-const emit = defineEmits<{ done: []; success: [] }>()
+type PaymentOutcome = 'success' | 'cancelled' | 'expired'
+
+const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>()
const { t } = useI18n()
const paymentStore = usePaymentStore()
@@ -154,7 +156,7 @@ const cancelling = ref(false)
const paidOrder = ref(null)
// Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired'
-const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null)
+const outcome = ref<PaymentOutcome | null>(null)
let pollTimer: ReturnType<typeof setInterval> | null = null
let countdownTimer: ReturnType<typeof setInterval> | null = null
@@ -194,10 +196,19 @@ const countdownDisplay = computed(() => {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ const win = window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ if (!win || win.closed) {
+ window.location.href = props.payUrl
+ }
}
}
+function setOutcome(next: PaymentOutcome) {
+ if (outcome.value === next) return
+ outcome.value = next
+ emit('settled', next)
+}
+
async function renderQR() {
await nextTick()
if (!qrCanvas.value || !qrUrl.value) return
@@ -214,23 +225,23 @@ async function pollStatus() {
if (order.status === 'COMPLETED' || order.status === 'PAID') {
cleanup()
paidOrder.value = order
- outcome.value = 'success'
+ setOutcome('success')
emit('success')
} else if (order.status === 'CANCELLED') {
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} else if (order.status === 'EXPIRED' || order.status === 'FAILED') {
cleanup()
- outcome.value = 'expired'
+ setOutcome('expired')
}
}
function startCountdown(seconds: number) {
remainingSeconds.value = Math.max(0, seconds)
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); return }
countdownTimer = setInterval(() => {
remainingSeconds.value--
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() }
}, 1000)
}
@@ -240,7 +251,7 @@ async function handleCancel() {
try {
await paymentAPI.cancelOrder(props.orderId)
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t('common.error')))
} finally {
diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
new file mode 100644
index 00000000..f5212f15
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
@@ -0,0 +1,163 @@
+import { describe, expect, it } from 'vitest'
+import type { CreateOrderResult, MethodLimit } from '@/types/payment'
+import {
+ decidePaymentLaunch,
+ getVisibleMethods,
+ readPaymentRecoverySnapshot,
+ type PaymentRecoverySnapshot,
+} from '@/components/payment/paymentFlow'
+
+function methodLimit(overrides: Partial = {}): MethodLimit {
+ return {
+ daily_limit: 0,
+ daily_used: 0,
+ daily_remaining: 0,
+ single_min: 0,
+ single_max: 0,
+ fee_rate: 0,
+ available: true,
+ ...overrides,
+ }
+}
+
+function createOrderResult(overrides: Partial = {}): CreateOrderResult {
+ return {
+ order_id: 101,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ expires_at: '2099-01-01T00:10:00.000Z',
+ ...overrides,
+ }
+}
+
+describe('getVisibleMethods', () => {
+ it('filters hidden provider methods and normalizes aliases', () => {
+ const visible = getVisibleMethods({
+ alipay_direct: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ stripe: methodLimit({ fee_rate: 3 }),
+ })
+
+ expect(visible).toEqual({
+ alipay: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ })
+ })
+
+ it('prefers canonical visible methods over aliases when both exist', () => {
+ const visible = getVisibleMethods({
+ alipay: methodLimit({ single_min: 2 }),
+ alipay_direct: methodLimit({ single_min: 9 }),
+ wxpay_direct: methodLimit({ fee_rate: 1.2 }),
+ })
+
+ expect(visible.alipay.single_min).toBe(2)
+ expect(visible.wxpay.fee_rate).toBe(1.2)
+ })
+})
+
+describe('decidePaymentLaunch', () => {
+ it('uses Stripe popup waiting flow for desktop Alipay client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ resume_token: 'resume-1',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('stripe_popup')
+ expect(decision.paymentState.paymentType).toBe('alipay')
+ expect(decision.stripeMethod).toBe('alipay')
+ expect(decision.recovery.resumeToken).toBe('resume-1')
+ })
+
+ it('uses Stripe route flow for mobile WeChat client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'subscription',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('stripe_route')
+ expect(decision.stripeMethod).toBe('wechat_pay')
+ expect(decision.paymentState.orderType).toBe('subscription')
+ })
+
+ it('keeps hosted redirect metadata for recovery flows', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/session/abc',
+ payment_mode: 'popup',
+ resume_token: 'resume-2',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('redirect_waiting')
+ expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc')
+ expect(decision.recovery.paymentMode).toBe('popup')
+ expect(decision.recovery.resumeToken).toBe('resume-2')
+ })
+})
+
+describe('readPaymentRecoverySnapshot', () => {
+ it('restores an unexpired snapshot when the resume token matches', () => {
+ const snapshot: PaymentRecoverySnapshot = {
+ orderId: 33,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/33',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-33',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }
+
+ const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'resume-33',
+ })
+
+ expect(restored?.orderId).toBe(33)
+ })
+
+ it('drops expired or mismatched recovery snapshots', () => {
+ const expiredSnapshot: PaymentRecoverySnapshot = {
+ orderId: 55,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2024-01-01T00:10:00.000Z',
+ paymentType: 'wxpay',
+ payUrl: 'https://pay.example.com/session/55',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-55',
+ createdAt: Date.UTC(2024, 0, 1, 0, 0, 0),
+ }
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), {
+ now: Date.UTC(2024, 0, 1, 0, 20, 0),
+ resumeToken: 'resume-55',
+ })).toBeNull()
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify({
+ ...expiredSnapshot,
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ }), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'other-token',
+ })).toBeNull()
+ })
+})
diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts
new file mode 100644
index 00000000..70225a0c
--- /dev/null
+++ b/frontend/src/components/payment/paymentFlow.ts
@@ -0,0 +1,197 @@
+import type { CreateOrderResult, MethodLimit, OrderType } from '@/types/payment'
+
+export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current'
+
+const VISIBLE_METHOD_ALIASES = {
+ alipay: 'alipay',
+ alipay_direct: 'alipay',
+ wxpay: 'wxpay',
+ wxpay_direct: 'wxpay',
+} as const
+
+export type VisiblePaymentMethod = 'alipay' | 'wxpay'
+export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
+export type PaymentLaunchKind =
+ | 'qr_waiting'
+ | 'redirect_waiting'
+ | 'stripe_popup'
+ | 'stripe_route'
+ | 'unhandled'
+
+export interface PaymentRecoverySnapshot {
+ orderId: number
+ amount: number
+ qrCode: string
+ expiresAt: string
+ paymentType: string
+ payUrl: string
+ clientSecret: string
+ payAmount: number
+ orderType: OrderType | ''
+ paymentMode: string
+ resumeToken: string
+ createdAt: number
+}
+
+export interface PaymentLaunchContext {
+ visibleMethod: string
+ orderType: OrderType
+ isMobile: boolean
+ now?: number
+ stripePopupUrl?: string
+ stripeRouteUrl?: string
+}
+
+export interface PaymentLaunchDecision {
+ kind: PaymentLaunchKind
+ paymentState: PaymentRecoverySnapshot
+ recovery: PaymentRecoverySnapshot
+ stripeMethod?: StripeVisibleMethod
+}
+
+type CreateOrderFlowResult = CreateOrderResult & {
+ resume_token?: string
+}
+
+type StorageWriter = Pick<Storage, 'setItem'>
+
+export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' {
+ const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES]
+ return normalized ?? ''
+}
+
+export function getVisibleMethods(methods: Record<string, MethodLimit>): Partial<Record<VisiblePaymentMethod, MethodLimit>> {
+  const visible: Partial<Record<VisiblePaymentMethod, MethodLimit>> = {}
+
+ Object.entries(methods).forEach(([type, limit]) => {
+ const normalized = normalizeVisibleMethod(type)
+ if (!normalized) return
+
+ const isCanonical = type === normalized
+ const existing = visible[normalized]
+ if (!existing || isCanonical) {
+ visible[normalized] = { ...limit }
+ }
+ })
+
+ return visible
+}
+
+export function decidePaymentLaunch(
+ result: CreateOrderFlowResult,
+ context: PaymentLaunchContext,
+): PaymentLaunchDecision {
+ const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod
+ const baseState = createPaymentRecoverySnapshot({
+ orderId: result.order_id,
+ amount: result.amount,
+ qrCode: result.qr_code || '',
+ expiresAt: result.expires_at || '',
+ paymentType: visibleMethod,
+ payUrl: result.pay_url || '',
+ clientSecret: result.client_secret || '',
+ payAmount: result.pay_amount,
+ orderType: context.orderType,
+ paymentMode: (result.payment_mode || '').trim(),
+ resumeToken: result.resume_token || '',
+ }, context.now)
+
+ if (baseState.clientSecret) {
+ const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
+ const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile
+ ? 'stripe_popup'
+ : 'stripe_route'
+ const payUrl = kind === 'stripe_popup'
+ ? context.stripePopupUrl || context.stripeRouteUrl || ''
+ : context.stripeRouteUrl || context.stripePopupUrl || ''
+ const paymentState = { ...baseState, payUrl }
+ return { kind, paymentState, recovery: paymentState, stripeMethod }
+ }
+
+ if (baseState.qrCode) {
+ return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (baseState.payUrl) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ return { kind: 'unhandled', paymentState: baseState, recovery: baseState }
+}
+
+export function createPaymentRecoverySnapshot(
+  state: Omit<PaymentRecoverySnapshot, 'createdAt'>,
+ now = Date.now(),
+): PaymentRecoverySnapshot {
+ return {
+ ...state,
+ createdAt: now,
+ }
+}
+
+export function writePaymentRecoverySnapshot(
+ storage: StorageWriter,
+ snapshot: PaymentRecoverySnapshot,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.setItem(key, JSON.stringify(snapshot))
+}
+
+export function clearPaymentRecoverySnapshot(
+  storage: Pick<Storage, 'removeItem'>,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.removeItem(key)
+}
+
+export function readPaymentRecoverySnapshot(
+ raw: string | null | undefined,
+ options: { now?: number; resumeToken?: string } = {},
+): PaymentRecoverySnapshot | null {
+ if (!raw) return null
+
+ try {
+    const parsed = JSON.parse(raw) as Partial<PaymentRecoverySnapshot>
+ if (
+ typeof parsed.orderId !== 'number'
+ || typeof parsed.amount !== 'number'
+ || typeof parsed.qrCode !== 'string'
+ || typeof parsed.expiresAt !== 'string'
+ || typeof parsed.paymentType !== 'string'
+ || typeof parsed.payUrl !== 'string'
+ || typeof parsed.clientSecret !== 'string'
+ || typeof parsed.payAmount !== 'number'
+ || typeof parsed.paymentMode !== 'string'
+ || typeof parsed.resumeToken !== 'string'
+ || typeof parsed.createdAt !== 'number'
+ ) {
+ return null
+ }
+
+ const now = options.now ?? Date.now()
+ const expiresAt = Date.parse(parsed.expiresAt)
+ if (Number.isFinite(expiresAt) && expiresAt <= now) {
+ return null
+ }
+ if (options.resumeToken && parsed.resumeToken && parsed.resumeToken !== options.resumeToken) {
+ return null
+ }
+
+ return {
+ orderId: parsed.orderId,
+ amount: parsed.amount,
+ qrCode: parsed.qrCode,
+ expiresAt: parsed.expiresAt,
+ paymentType: parsed.paymentType,
+ payUrl: parsed.payUrl,
+ clientSecret: parsed.clientSecret,
+ payAmount: parsed.payAmount,
+ orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance',
+ paymentMode: parsed.paymentMode,
+ resumeToken: parsed.resumeToken,
+ createdAt: parsed.createdAt,
+ }
+ } catch {
+ return null
+ }
+}
diff --git a/frontend/src/router/__tests__/wechat-route.spec.ts b/frontend/src/router/__tests__/wechat-route.spec.ts
new file mode 100644
index 00000000..84b20452
--- /dev/null
+++ b/frontend/src/router/__tests__/wechat-route.spec.ts
@@ -0,0 +1,55 @@
+import { describe, expect, it, vi } from 'vitest'
+
+const authStore = vi.hoisted(() => ({
+ checkAuth: vi.fn(),
+ isAuthenticated: false,
+ isAdmin: false,
+ isSimpleMode: false,
+}))
+
+const appStore = vi.hoisted(() => ({
+ siteName: 'Sub2API',
+ backendModeEnabled: false,
+  cachedPublicSettings: null as null | Record<string, unknown>,
+}))
+
+vi.mock('@/stores/auth', () => ({
+ useAuthStore: () => authStore,
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => appStore,
+}))
+
+vi.mock('@/stores/adminSettings', () => ({
+ useAdminSettingsStore: () => ({
+ customMenuItems: [],
+ }),
+}))
+
+vi.mock('@/composables/useNavigationLoading', () => ({
+ useNavigationLoadingState: () => ({
+ startNavigation: vi.fn(),
+ endNavigation: vi.fn(),
+ isLoading: { value: false },
+ }),
+}))
+
+vi.mock('@/composables/useRoutePrefetch', () => ({
+ useRoutePrefetch: () => ({
+ triggerPrefetch: vi.fn(),
+ cancelPendingPrefetch: vi.fn(),
+ resetPrefetchState: vi.fn(),
+ }),
+}))
+
+describe('router WeChat OAuth route', () => {
+ it('registers the WeChat callback route as a public route', async () => {
+ const { default: router } = await import('@/router')
+ const route = router.getRoutes().find((record) => record.name === 'WeChatOAuthCallback')
+
+ expect(route?.path).toBe('/auth/wechat/callback')
+ expect(route?.meta.requiresAuth).toBe(false)
+ expect(route?.meta.title).toBe('WeChat OAuth Callback')
+ })
+})
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index ad6e71c4..beaa1da2 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -83,6 +83,15 @@ const routes: RouteRecordRaw[] = [
title: 'LinuxDo OAuth Callback'
}
},
+ {
+ path: '/auth/wechat/callback',
+ name: 'WeChatOAuthCallback',
+ component: () => import('@/views/auth/WechatCallbackView.vue'),
+ meta: {
+ requiresAuth: false,
+ title: 'WeChat OAuth Callback'
+ }
+ },
{
path: '/auth/oidc/callback',
name: 'OIDCOAuthCallback',
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index 1995383d..1b1af87b 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -336,6 +336,7 @@ export const useAppStore = defineStore('app', () => {
custom_menu_items: [],
custom_endpoints: [],
linuxdo_oauth_enabled: false,
+ wechat_oauth_enabled: false,
oidc_oauth_enabled: false,
oidc_oauth_provider_name: 'OIDC',
backend_mode_enabled: false,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 89fd777f..529eff55 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -123,6 +123,7 @@ export interface PublicSettings {
custom_menu_items: CustomMenuItem[]
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean
+ wechat_oauth_enabled: boolean
oidc_oauth_enabled: boolean
oidc_oauth_provider_name: string
backend_mode_enabled: boolean
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 8bfa0f2b..0d23baa5 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1586,6 +1586,221 @@
+
+
+
+
+ {{ localText('认证来源默认值', 'Auth Source Defaults') }}
+
+
+ {{
+ localText(
+ '按注册来源配置新用户默认余额、并发、订阅与授权策略。',
+ 'Configure per-source default balance, concurrency, subscriptions, and grant rules.'
+ )
+ }}
+
+
+
+
+
+
+ {{ localText('第三方注册强制补充邮箱', 'Require email on third-party signup') }}
+
+
+ {{
+ localText(
+ '启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。',
+ 'When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.'
+ )
+ }}
+
+
+
+
+
+
+
+
+
{{ authSource.title }}
+
+ {{ authSource.description }}
+
+
+
+
+
+
+
+
+
+ {{ localText('注册即授权', 'Grant on signup') }}
+
+
+ {{
+ localText(
+ '来源首次注册成功后立即发放默认权益。',
+ 'Grant default entitlements immediately after signup.'
+ )
+ }}
+
+
+
+
+
+
+
+
+ {{ localText('首次绑定时授权', 'Grant on first bind') }}
+
+
+ {{
+ localText(
+ '来源首次绑定到现有账号时发放默认权益。',
+ 'Grant default entitlements when the source is first bound to an existing user.'
+ )
+ }}
+
+
+
+
+
+
+
+
+
+
+ {{ localText('默认订阅', 'Default subscriptions') }}
+
+
+ {{
+ localText(
+ '仅对当前认证来源生效,未配置时不追加来源专属订阅。',
+ 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.'
+ )
+ }}
+
+
+
+ {{ t('admin.settings.defaults.addDefaultSubscription') }}
+
+
+
+
+ {{
+ localText(
+ '当前来源未配置专属默认订阅。',
+ 'No source-specific default subscriptions configured.'
+ )
+ }}
+
+
+
+
+
+
+ {{ t('admin.settings.defaults.subscriptionGroup') }}
+
+
+
+
+
+ {{ t('admin.settings.defaults.subscriptionGroup') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.defaults.subscriptionValidityDays') }}
+
+
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+
+
+
+
@@ -1643,19 +1858,38 @@
-
-
-
- {{ t('admin.settings.scheduling.allowUngroupedKey') }}
+
+
+
+
+ {{ t('admin.settings.scheduling.allowUngroupedKey') }}
+
+
+ {{ t('admin.settings.scheduling.allowUngroupedKeyHint') }}
+
+
+
+
+
-
- {{ t('admin.settings.scheduling.allowUngroupedKeyHint') }}
-
-
-
-
-
+
+
+
+
+ {{ localText('OpenAI 高级调度器', 'OpenAI advanced scheduler') }}
+
+
+ {{
+ localText(
+ '切换 OpenAI 侧新增的高级调度开关,供当前分支实验性调度逻辑使用。',
+ 'Toggles the new OpenAI advanced scheduler flag for the experimental routing logic on this branch.'
+ )
+ }}
+
+
+
+
@@ -2450,6 +2684,59 @@
+
+
+
+
+
+ {{
+ localText(
+ `${visibleMethod.title} 可见方式`,
+ `${visibleMethod.title} visible method`
+ )
+ }}
+
+
+ {{
+ localText(
+ '控制前台结算页是否展示该方式,以及展示时使用的来源键。',
+ 'Controls whether checkout shows this method and which source key it exposes.'
+ )
+ }}
+
+
+
+
+
+
+
+ {{ localText('来源键', 'Source key') }}
+
+
+
+ {{
+ localText(
+ '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。',
+ 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.'
+ )
+ }}
+
+
+
+
@@ -2827,7 +3114,14 @@
import { ref, reactive, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api'
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ normalizeDefaultSubscriptionSettings,
+} from '@/api/admin/settings'
import type {
+ AuthSourceDefaultsState,
+ AuthSourceType,
SystemSettings,
UpdateSettingsRequest,
DefaultSubscriptionSetting,
@@ -2864,6 +3158,10 @@ const { t, locale } = useI18n()
const appStore = useAppStore()
const adminSettingsStore = useAdminSettingsStore()
+function localText(zh: string, en: string): string {
+ return locale.value.startsWith('zh') ? zh : en
+}
+
type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup'
const activeTab =
  ref<SettingsTab>('general')
const settingsTabs = [
@@ -2960,6 +3258,12 @@ type SettingsForm = SystemSettings & {
turnstile_secret_key: string
linuxdo_connect_client_secret: string
oidc_connect_client_secret: string
+ force_email_on_third_party_signup: boolean
+ payment_visible_method_alipay_source: string
+ payment_visible_method_wxpay_source: string
+ payment_visible_method_alipay_enabled: boolean
+ payment_visible_method_wxpay_enabled: boolean
+ openai_advanced_scheduler_enabled: boolean
}
const form = reactive({
@@ -2974,6 +3278,7 @@ const form = reactive({
default_balance: 0,
default_concurrency: 1,
default_subscriptions: [],
+ force_email_on_third_party_signup: false,
site_name: 'Sub2API',
site_logo: '',
site_subtitle: 'Subscription to API Conversion Platform',
@@ -2983,7 +3288,7 @@ const form = reactive({
home_content: '',
backend_mode_enabled: false,
hide_ccs_import_button: false,
- payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling',
+ payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', payment_visible_method_alipay_source: '', payment_visible_method_wxpay_source: '', payment_visible_method_alipay_enabled: false, payment_visible_method_wxpay_enabled: false,
table_default_page_size: tablePageSizeDefault,
table_page_size_options: [10, 20, 50, 100],
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
@@ -3051,6 +3356,7 @@ const form = reactive({
max_claude_code_version: '',
// 分组隔离
allow_ungrouped_key_scheduling: false,
+ openai_advanced_scheduler_enabled: false,
// Gateway forwarding behavior
enable_fingerprint_unification: true,
enable_metadata_passthrough: false,
@@ -3063,6 +3369,74 @@ const form = reactive({
account_quota_notify_emails: [] as NotifyEmailEntry[]
})
+const authSourceDefaults = reactive<AuthSourceDefaultsState>(buildAuthSourceDefaultsState({}))
+
+const authSourceDefaultsMeta = computed(() => [
+ {
+ source: 'email' as AuthSourceType,
+ title: localText('邮箱注册', 'Email signup'),
+ description: localText('适用于邮箱密码注册的新用户默认配额。', 'Default quota grants for email-password signups.')
+ },
+ {
+ source: 'linuxdo' as AuthSourceType,
+ title: localText('Linux DO 登录', 'Linux DO signup'),
+ description: localText('适用于 Linux DO 第三方注册的新用户默认配额。', 'Default quota grants for Linux DO signups.')
+ },
+ {
+ source: 'oidc' as AuthSourceType,
+ title: localText('OIDC 登录', 'OIDC signup'),
+ description: localText('适用于 OIDC 第三方注册的新用户默认配额。', 'Default quota grants for OIDC signups.')
+ },
+ {
+ source: 'wechat' as AuthSourceType,
+ title: localText('微信登录', 'WeChat signup'),
+ description: localText('适用于微信第三方注册的新用户默认配额。', 'Default quota grants for WeChat signups.')
+ },
+])
+
+const paymentVisibleMethodCards = computed(() => [
+ {
+ key: 'alipay' as const,
+ title: t('payment.methods.alipay'),
+ enabledField: 'payment_visible_method_alipay_enabled' as const,
+ sourceField: 'payment_visible_method_alipay_source' as const,
+ },
+ {
+ key: 'wxpay' as const,
+ title: t('payment.methods.wxpay'),
+ enabledField: 'payment_visible_method_wxpay_enabled' as const,
+ sourceField: 'payment_visible_method_wxpay_source' as const,
+ },
+])
+
+function getPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay'): boolean {
+ return method === 'alipay'
+ ? form.payment_visible_method_alipay_enabled
+ : form.payment_visible_method_wxpay_enabled
+}
+
+function setPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay', enabled: boolean) {
+ if (method === 'alipay') {
+ form.payment_visible_method_alipay_enabled = enabled
+ return
+ }
+ form.payment_visible_method_wxpay_enabled = enabled
+}
+
+function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string {
+ return method === 'alipay'
+ ? form.payment_visible_method_alipay_source
+ : form.payment_visible_method_wxpay_source
+}
+
+function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) {
+ if (method === 'alipay') {
+ form.payment_visible_method_alipay_source = source
+ return
+ }
+ form.payment_visible_method_wxpay_source = source
+}
+
// Proxies for web search emulation ProxySelector
const webSearchProxies = ref([])
@@ -3428,15 +3802,9 @@ async function loadSettings() {
(form as Record)[key] = value
}
}
+ Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings))
form.backend_mode_enabled = settings.backend_mode_enabled
- form.default_subscriptions = Array.isArray(settings.default_subscriptions)
- ? settings.default_subscriptions
- .filter((item) => item.group_id > 0 && item.validity_days > 0)
- .map((item) => ({
- group_id: item.group_id,
- validity_days: item.validity_days
- }))
- : []
+ form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions)
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
settings.registration_email_suffix_whitelist
)
@@ -3471,10 +3839,18 @@ async function loadSubscriptionGroups() {
}
}
+function findNextAvailableSubscriptionGroup(
+ existingGroupIDs: number[]
+): AdminGroup | undefined {
+ const existing = new Set(existingGroupIDs)
+ return subscriptionGroups.value.find((group) => !existing.has(group.id))
+}
+
function addDefaultSubscription() {
if (subscriptionGroups.value.length === 0) return
- const existing = new Set(form.default_subscriptions.map((item) => item.group_id))
- const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id))
+ const candidate = findNextAvailableSubscriptionGroup(
+ form.default_subscriptions.map((item) => item.group_id)
+ )
if (!candidate) return
form.default_subscriptions.push({
group_id: candidate.id,
@@ -3486,6 +3862,36 @@ function removeDefaultSubscription(index: number) {
form.default_subscriptions.splice(index, 1)
}
+function addAuthSourceDefaultSubscription(source: AuthSourceType) {
+ if (subscriptionGroups.value.length === 0) return
+ const candidate = findNextAvailableSubscriptionGroup(
+ authSourceDefaults[source].subscriptions.map((item) => item.group_id)
+ )
+ if (!candidate) return
+ authSourceDefaults[source].subscriptions.push({
+ group_id: candidate.id,
+ validity_days: 30
+ })
+}
+
+function removeAuthSourceDefaultSubscription(source: AuthSourceType, index: number) {
+ authSourceDefaults[source].subscriptions.splice(index, 1)
+}
+
+function findDuplicateDefaultSubscription(
+ subscriptions: DefaultSubscriptionSetting[]
+): DefaultSubscriptionSetting | undefined {
+  const seenGroupIDs = new Set<number>()
+
+ return subscriptions.find((item) => {
+ if (seenGroupIDs.has(item.group_id)) {
+ return true
+ }
+ seenGroupIDs.add(item.group_id)
+ return false
+ })
+}
+
async function saveSettings() {
saving.value = true
try {
@@ -3520,21 +3926,12 @@ async function saveSettings() {
form.table_default_page_size = normalizedTableDefaultPageSize
form.table_page_size_options = normalizedTablePageSizeOptions
- const normalizedDefaultSubscriptions = form.default_subscriptions
- .filter((item) => item.group_id > 0 && item.validity_days > 0)
- .map((item: DefaultSubscriptionSetting) => ({
- group_id: item.group_id,
- validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days)))
- }))
-
- const seenGroupIDs = new Set()
- const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => {
- if (seenGroupIDs.has(item.group_id)) {
- return true
- }
- seenGroupIDs.add(item.group_id)
- return false
- })
+ const normalizedDefaultSubscriptions = normalizeDefaultSubscriptionSettings(
+ form.default_subscriptions
+ )
+ const duplicateDefaultSubscription = findDuplicateDefaultSubscription(
+ normalizedDefaultSubscriptions
+ )
if (duplicateDefaultSubscription) {
appStore.showError(
t('admin.settings.defaults.defaultSubscriptionsDuplicate', {
@@ -3544,6 +3941,23 @@ async function saveSettings() {
return
}
+ for (const authSource of authSourceDefaultsMeta.value) {
+ authSourceDefaults[authSource.source].subscriptions = normalizeDefaultSubscriptionSettings(
+ authSourceDefaults[authSource.source].subscriptions
+ )
+ const duplicate = findDuplicateDefaultSubscription(
+ authSourceDefaults[authSource.source].subscriptions
+ )
+ if (duplicate) {
+ appStore.showError(
+ `${authSource.title}: ${t('admin.settings.defaults.defaultSubscriptionsDuplicate', {
+ groupId: duplicate.group_id
+ })}`
+ )
+ return
+ }
+ }
+
// Validate URL fields — novalidate disables browser-native checks, so we validate here
const isValidHttpUrl = (url: string): boolean => {
if (!url) return true
@@ -3571,6 +3985,7 @@ async function saveSettings() {
default_balance: form.default_balance,
default_concurrency: form.default_concurrency,
default_subscriptions: normalizedDefaultSubscriptions,
+ force_email_on_third_party_signup: form.force_email_on_third_party_signup,
site_name: form.site_name,
site_logo: form.site_logo,
site_subtitle: form.site_subtitle,
@@ -3655,6 +4070,11 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
+ payment_visible_method_alipay_source: form.payment_visible_method_alipay_source,
+ payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source,
+ payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled,
+ payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled,
+ openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled,
// Balance & quota notification
balance_low_notify_enabled: form.balance_low_notify_enabled,
balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
@@ -3663,12 +4083,15 @@ async function saveSettings() {
account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''),
}
+ appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults)
+
const updated = await adminAPI.settings.updateSettings(payload)
for (const [key, value] of Object.entries(updated)) {
if (value !== null && value !== undefined) {
(form as Record)[key] = value
}
}
+ Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated))
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
updated.registration_email_suffix_whitelist
)
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index af48959b..0a125def 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -11,32 +11,94 @@
-
-
- {{ t('auth.linuxdo.invitationRequired') }}
-
-
-
+
+
+
+
+
+ Use LinuxDo profile details
+
+
+ Choose whether to apply the nickname or avatar from LinuxDo to this account.
+
+
+
+
+
+
+
+ Use display name
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ Use avatar
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
-
-
- {{ invitationError }}
+
+
+
+ {{ t('auth.linuxdo.invitationRequired') }}
-
-
- {{ isSubmitting ? t('auth.linuxdo.completing') : t('auth.linuxdo.completeRegistration') }}
-
+
+
+
+
+
+ {{ invitationError }}
+
+
+
+ {{ isSubmitting ? t('auth.linuxdo.completing') : t('auth.linuxdo.completeRegistration') }}
+
+
+
+
+
+ Review the LinuxDo profile details before continuing.
+
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
@@ -71,7 +133,12 @@ import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue'
import { useAuthStore, useAppStore } from '@/stores'
-import { completeLinuxDoOAuthRegistration } from '@/api/auth'
+import {
+ completeLinuxDoOAuthRegistration,
+ exchangePendingOAuthCompletion,
+ type OAuthAdoptionDecision,
+ type PendingOAuthExchangeResponse
+} from '@/api/auth'
const route = useRoute()
const router = useRouter()
@@ -85,11 +152,16 @@ const errorMessage = ref('')
// Invitation code flow state
const needsInvitation = ref(false)
-const pendingOAuthToken = ref('')
const invitationCode = ref('')
const isSubmitting = ref(false)
const invitationError = ref('')
const redirectTo = ref('/dashboard')
+const adoptionRequired = ref(false)
+const suggestedDisplayName = ref('')
+const suggestedAvatarUrl = ref('')
+const adoptDisplayName = ref(true)
+const adoptAvatar = ref(true)
+const needsAdoptionConfirmation = ref(false)
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -106,6 +178,54 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
return path
}
+function currentAdoptionDecision(): OAuthAdoptionDecision {
+ return {
+ adoptDisplayName: adoptDisplayName.value,
+ adoptAvatar: adoptAvatar.value
+ }
+}
+
+function applyAdoptionSuggestionState(completion: {
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}) {
+ adoptionRequired.value = completion.adoption_required === true
+ suggestedDisplayName.value = completion.suggested_display_name || ''
+ suggestedAvatarUrl.value = completion.suggested_avatar_url || ''
+
+ if (!suggestedDisplayName.value) {
+ adoptDisplayName.value = false
+ }
+ if (!suggestedAvatarUrl.value) {
+ adoptAvatar.value = false
+ }
+}
+
+function hasSuggestedProfile(completion: {
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (!completion.access_token) {
+ throw new Error(t('auth.linuxdo.callbackMissingToken'))
+ }
+
+ if (completion.refresh_token) {
+ localStorage.setItem('refresh_token', completion.refresh_token)
+ }
+ if (completion.expires_in) {
+ localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ }
+
+ await authStore.setToken(completion.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+}
+
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
@@ -113,8 +233,8 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
const tokenData = await completeLinuxDoOAuthRegistration(
- pendingOAuthToken.value,
- invitationCode.value.trim()
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
)
if (tokenData.refresh_token) {
localStorage.setItem('refresh_token', tokenData.refresh_token)
@@ -134,63 +254,65 @@ async function handleSubmitInvitation() {
}
}
+async function handleContinueLogin() {
+ isSubmitting.value = true
+ try {
+ const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
+ await finalizeLogin(completion, redirectTo.value)
+ } catch (e: unknown) {
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
+ appStore.showError(errorMessage.value)
+ needsAdoptionConfirmation.value = false
+ } finally {
+ isSubmitting.value = false
+ }
+}
+
onMounted(async () => {
const params = parseFragmentParams()
-
- const token = params.get('access_token') || ''
- const refreshToken = params.get('refresh_token') || ''
- const expiresInStr = params.get('expires_in') || ''
- const redirect = sanitizeRedirectPath(
- params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
- )
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
if (error) {
- if (error === 'invitation_required') {
- pendingOAuthToken.value = params.get('pending_oauth_token') || ''
- redirectTo.value = sanitizeRedirectPath(params.get('redirect'))
- if (!pendingOAuthToken.value) {
- errorMessage.value = t('auth.linuxdo.invalidPendingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
- needsInvitation.value = true
- isProcessing.value = false
- return
- }
errorMessage.value = errorDesc || error
appStore.showError(errorMessage.value)
isProcessing.value = false
return
}
- if (!token) {
- errorMessage.value = t('auth.linuxdo.callbackMissingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
-
try {
- // Store refresh token and expires_at (convert to timestamp) if provided
- if (refreshToken) {
- localStorage.setItem('refresh_token', refreshToken)
+ const completion = await exchangePendingOAuthCompletion()
+ const redirect = sanitizeRedirectPath(
+ completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
+ )
+ applyAdoptionSuggestionState(completion)
+ redirectTo.value = redirect
+
+ if (completion.error === 'invitation_required') {
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
}
- if (expiresInStr) {
- const expiresIn = parseInt(expiresInStr, 10)
- if (!isNaN(expiresIn)) {
- localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000))
- }
+
+ if (adoptionRequired.value && hasSuggestedProfile(completion)) {
+ needsAdoptionConfirmation.value = true
+ isProcessing.value = false
+ return
}
- await authStore.setToken(token)
- appStore.showSuccess(t('auth.loginSuccess'))
- await router.replace(redirect)
+ await finalizeLogin(completion, redirect)
} catch (e: unknown) {
- const err = e as { message?: string; response?: { data?: { detail?: string } } }
- errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed')
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
appStore.showError(errorMessage.value)
isProcessing.value = false
}
@@ -209,4 +331,3 @@ onMounted(async () => {
transform: translateY(-8px);
}
-
diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue
index 70b64e3f..fa4ac34c 100644
--- a/frontend/src/views/auth/LoginView.vue
+++ b/frontend/src/views/auth/LoginView.vue
@@ -11,12 +11,17 @@
-
+
+
(false)
const turnstileEnabled = ref(false)
const turnstileSiteKey = ref('')
const linuxdoOAuthEnabled = ref(false)
+const wechatOAuthEnabled = ref(false)
const backendModeEnabled = ref(false)
const oidcOAuthEnabled = ref(false)
const oidcOAuthProviderName = ref('OIDC')
@@ -267,6 +274,7 @@ onMounted(async () => {
turnstileEnabled.value = settings.turnstile_enabled
turnstileSiteKey.value = settings.turnstile_site_key || ''
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
+ wechatOAuthEnabled.value = settings.wechat_oauth_enabled
backendModeEnabled.value = settings.backend_mode_enabled
oidcOAuthEnabled.value = settings.oidc_oauth_enabled
oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC'
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index a6cb6c12..55f8af6e 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -15,36 +15,99 @@
-
-
- {{ t('auth.oidc.invitationRequired', { providerName }) }}
-
-
-
+
+
+
+
+
+ Use {{ providerName }} profile details
+
+
+ Choose whether to apply the nickname or avatar from {{ providerName }} to this
+ account.
+
+
+
+
+
+
+
+ Use display name
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ Use avatar
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
-
-
- {{ invitationError }}
+
+
+
+ {{ t('auth.oidc.invitationRequired', { providerName }) }}
-
-
- {{
- isSubmitting
- ? t('auth.oidc.completing')
- : t('auth.oidc.completeRegistration')
- }}
-
+
+
+
+
+
+ {{ invitationError }}
+
+
+
+ {{
+ isSubmitting
+ ? t('auth.oidc.completing')
+ : t('auth.oidc.completeRegistration')
+ }}
+
+
+
+
+
+ Review the {{ providerName }} profile details before continuing.
+
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
@@ -81,7 +144,10 @@ import Icon from '@/components/icons/Icon.vue'
import { useAuthStore, useAppStore } from '@/stores'
import {
completeOIDCOAuthRegistration,
- getPublicSettings
+ exchangePendingOAuthCompletion,
+ getPublicSettings,
+ type OAuthAdoptionDecision,
+ type PendingOAuthExchangeResponse
} from '@/api/auth'
const route = useRoute()
@@ -95,12 +161,17 @@ const isProcessing = ref(true)
const errorMessage = ref('')
const needsInvitation = ref(false)
-const pendingOAuthToken = ref('')
const invitationCode = ref('')
const isSubmitting = ref(false)
const invitationError = ref('')
const redirectTo = ref('/dashboard')
const providerName = ref('OIDC')
+const adoptionRequired = ref(false)
+const suggestedDisplayName = ref('')
+const suggestedAvatarUrl = ref('')
+const adoptDisplayName = ref(true)
+const adoptAvatar = ref(true)
+const needsAdoptionConfirmation = ref(false)
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -129,6 +200,54 @@ async function loadProviderName() {
}
}
+function currentAdoptionDecision(): OAuthAdoptionDecision {
+ return {
+ adoptDisplayName: adoptDisplayName.value,
+ adoptAvatar: adoptAvatar.value
+ }
+}
+
+function applyAdoptionSuggestionState(completion: {
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}) {
+ adoptionRequired.value = completion.adoption_required === true
+ suggestedDisplayName.value = completion.suggested_display_name || ''
+ suggestedAvatarUrl.value = completion.suggested_avatar_url || ''
+
+ if (!suggestedDisplayName.value) {
+ adoptDisplayName.value = false
+ }
+ if (!suggestedAvatarUrl.value) {
+ adoptAvatar.value = false
+ }
+}
+
+function hasSuggestedProfile(completion: {
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (!completion.access_token) {
+ throw new Error(t('auth.oidc.callbackMissingToken'))
+ }
+
+ if (completion.refresh_token) {
+ localStorage.setItem('refresh_token', completion.refresh_token)
+ }
+ if (completion.expires_in) {
+ localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ }
+
+ await authStore.setToken(completion.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+}
+
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
@@ -136,8 +255,8 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
const tokenData = await completeOIDCOAuthRegistration(
- pendingOAuthToken.value,
- invitationCode.value.trim()
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
)
if (tokenData.refresh_token) {
localStorage.setItem('refresh_token', tokenData.refresh_token)
@@ -157,63 +276,67 @@ async function handleSubmitInvitation() {
}
}
+async function handleContinueLogin() {
+ isSubmitting.value = true
+ try {
+ const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
+ await finalizeLogin(completion, redirectTo.value)
+ } catch (e: unknown) {
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
+ appStore.showError(errorMessage.value)
+ needsAdoptionConfirmation.value = false
+ } finally {
+ isSubmitting.value = false
+ }
+}
+
onMounted(async () => {
void loadProviderName()
const params = parseFragmentParams()
- const token = params.get('access_token') || ''
- const refreshToken = params.get('refresh_token') || ''
- const expiresInStr = params.get('expires_in') || ''
- const redirect = sanitizeRedirectPath(
- params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
- )
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
if (error) {
- if (error === 'invitation_required') {
- pendingOAuthToken.value = params.get('pending_oauth_token') || ''
- redirectTo.value = sanitizeRedirectPath(params.get('redirect'))
- if (!pendingOAuthToken.value) {
- errorMessage.value = t('auth.oidc.invalidPendingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
- needsInvitation.value = true
- isProcessing.value = false
- return
- }
errorMessage.value = errorDesc || error
appStore.showError(errorMessage.value)
isProcessing.value = false
return
}
- if (!token) {
- errorMessage.value = t('auth.oidc.callbackMissingToken')
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
-
try {
- if (refreshToken) {
- localStorage.setItem('refresh_token', refreshToken)
+ const completion = await exchangePendingOAuthCompletion()
+ const redirect = sanitizeRedirectPath(
+ completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
+ )
+ applyAdoptionSuggestionState(completion)
+ redirectTo.value = redirect
+
+ if (completion.error === 'invitation_required') {
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
}
- if (expiresInStr) {
- const expiresIn = parseInt(expiresInStr, 10)
- if (!isNaN(expiresIn)) {
- localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000))
- }
+
+ if (adoptionRequired.value && hasSuggestedProfile(completion)) {
+ needsAdoptionConfirmation.value = true
+ isProcessing.value = false
+ return
}
- await authStore.setToken(token)
- appStore.showSuccess(t('auth.loginSuccess'))
- await router.replace(redirect)
+ await finalizeLogin(completion, redirect)
} catch (e: unknown) {
- const err = e as { message?: string; response?: { data?: { detail?: string } } }
- errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed')
+ const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
+ errorMessage.value =
+ err.response?.data?.detail ||
+ err.response?.data?.message ||
+ err.message ||
+ t('auth.loginFailed')
appStore.showError(errorMessage.value)
isProcessing.value = false
}
diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue
index bc8b8dce..378f9d8a 100644
--- a/frontend/src/views/auth/RegisterView.vue
+++ b/frontend/src/views/auth/RegisterView.vue
@@ -11,12 +11,17 @@
-
+
+
(false)
const turnstileSiteKey = ref('')
const siteName = ref('Sub2API')
const linuxdoOAuthEnabled = ref(false)
+const wechatOAuthEnabled = ref(false)
const oidcOAuthEnabled = ref(false)
const oidcOAuthProviderName = ref('OIDC')
const registrationEmailSuffixWhitelist = ref([])
@@ -397,6 +404,7 @@ onMounted(async () => {
turnstileSiteKey.value = settings.turnstile_site_key || ''
siteName.value = settings.site_name || 'Sub2API'
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
+ wechatOAuthEnabled.value = settings.wechat_oauth_enabled
oidcOAuthEnabled.value = settings.oidc_oauth_enabled
oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC'
registrationEmailSuffixWhitelist.value = normalizeRegistrationEmailSuffixWhitelist(
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
new file mode 100644
index 00000000..407b395b
--- /dev/null
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -0,0 +1,361 @@
+
+
+
+
+
+ {{ t('auth.oidc.callbackTitle', { providerName }) }}
+
+
+ {{
+ isProcessing
+ ? t('auth.oidc.callbackProcessing', { providerName })
+ : t('auth.oidc.callbackHint')
+ }}
+
+
+
+
+
+
+
+
+
+ Use {{ providerName }} profile details
+
+
+ Choose whether to apply the nickname or avatar from {{ providerName }} to this account.
+
+
+
+
+
+
+
+ Use display name
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ Use avatar
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
+
+
+
+
+ {{ t('auth.oidc.invitationRequired', { providerName }) }}
+
+
+
+
+
+
+ {{ invitationError }}
+
+
+
+ {{
+ isSubmitting
+ ? t('auth.oidc.completing')
+ : t('auth.oidc.completeRegistration')
+ }}
+
+
+
+
+
+ Review the {{ providerName }} profile details before continuing.
+
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ errorMessage }}
+
+
+ {{ t('auth.oidc.backToLogin') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
new file mode 100644
index 00000000..60a40474
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -0,0 +1,180 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import LinuxDoCallbackView from '../LinuxDoCallbackView.vue'
+
+const replace = vi.fn()
+const showSuccess = vi.fn()
+const showError = vi.fn()
+const setToken = vi.fn()
+const exchangePendingOAuthCompletion = vi.fn()
+const completeLinuxDoOAuthRegistration = vi.fn()
+
+vi.mock('vue-router', () => ({
+ useRoute: () => ({
+ query: {}
+ }),
+ useRouter: () => ({
+ replace
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+vi.mock('@/stores', () => ({
+ useAuthStore: () => ({
+ setToken
+ }),
+ useAppStore: () => ({
+ showSuccess,
+ showError
+ })
+}))
+
+vi.mock('@/api/auth', () => ({
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args)
+}))
+
+describe('LinuxDoCallbackView', () => {
+ beforeEach(() => {
+ replace.mockReset()
+ showSuccess.mockReset()
+ showError.mockReset()
+ setToken.mockReset()
+ exchangePendingOAuthCompletion.mockReset()
+ completeLinuxDoOAuthRegistration.mockReset()
+ })
+
+ it('does not send adoption decisions during the initial exchange', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true
+ })
+ setToken.mockResolvedValue({})
+
+ mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+ })
+
+ it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'LinuxDo Nick',
+ suggested_avatar_url: 'https://cdn.example/linuxdo.png'
+ })
+ .mockResolvedValueOnce({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('LinuxDo Nick')
+ expect(setToken).not.toHaveBeenCalled()
+ expect(replace).not.toHaveBeenCalled()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ await checkboxes[1].setValue(false)
+
+ const buttons = wrapper.findAll('button')
+ expect(buttons).toHaveLength(1)
+ await buttons[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+ expect(setToken).toHaveBeenCalledWith('access-token')
+ expect(replace).toHaveBeenCalledWith('/dashboard')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'invitation_required',
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'LinuxDo Nick',
+ suggested_avatar_url: 'https://cdn.example/linuxdo.png'
+ })
+ completeLinuxDoOAuthRegistration.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('LinuxDo Nick')
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+
+ await checkboxes[0].setValue(false)
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+
+ expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+ })
+})
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
new file mode 100644
index 00000000..299c0746
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -0,0 +1,191 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import OidcCallbackView from '../OidcCallbackView.vue'
+
+const replace = vi.fn()
+const showSuccess = vi.fn()
+const showError = vi.fn()
+const setToken = vi.fn()
+const exchangePendingOAuthCompletion = vi.fn()
+const completeOIDCOAuthRegistration = vi.fn()
+const getPublicSettings = vi.fn()
+
+vi.mock('vue-router', () => ({
+ useRoute: () => ({
+ query: {}
+ }),
+ useRouter: () => ({
+ replace
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (!params?.providerName) {
+ return key
+ }
+ return `${key}:${params.providerName}`
+ }
+ })
+ }
+})
+
+vi.mock('@/stores', () => ({
+ useAuthStore: () => ({
+ setToken
+ }),
+ useAppStore: () => ({
+ showSuccess,
+ showError
+ })
+}))
+
+vi.mock('@/api/auth', () => ({
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettings(...args)
+}))
+
+describe('OidcCallbackView', () => {
+ beforeEach(() => {
+ replace.mockReset()
+ showSuccess.mockReset()
+ showError.mockReset()
+ setToken.mockReset()
+ exchangePendingOAuthCompletion.mockReset()
+ completeOIDCOAuthRegistration.mockReset()
+ getPublicSettings.mockReset()
+ getPublicSettings.mockResolvedValue({
+ oidc_oauth_provider_name: 'ExampleID'
+ })
+ })
+
+ it('does not send adoption decisions during the initial exchange', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true
+ })
+ setToken.mockResolvedValue({})
+
+ mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+ })
+
+ it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: 'https://cdn.example/oidc.png'
+ })
+ .mockResolvedValueOnce({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('OIDC Nick')
+ expect(setToken).not.toHaveBeenCalled()
+ expect(replace).not.toHaveBeenCalled()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ await checkboxes[0].setValue(false)
+
+ const buttons = wrapper.findAll('button')
+ expect(buttons).toHaveLength(1)
+ await buttons[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+ expect(setToken).toHaveBeenCalledWith('access-token')
+ expect(replace).toHaveBeenCalledWith('/dashboard')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'invitation_required',
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: 'https://cdn.example/oidc.png'
+ })
+ completeOIDCOAuthRegistration.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('OIDC Nick')
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+
+ await checkboxes[1].setValue(false)
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+
+ expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+ })
+})
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
new file mode 100644
index 00000000..a9e2ada2
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -0,0 +1,241 @@
+import { flushPromises, mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatCallbackView from '@/views/auth/WechatCallbackView.vue'
+
+const {
+ postMock,
+ replaceMock,
+ setTokenMock,
+ showSuccessMock,
+ showErrorMock,
+ routeState,
+} = vi.hoisted(() => ({
+ postMock: vi.fn(),
+ replaceMock: vi.fn(),
+ setTokenMock: vi.fn(),
+ showSuccessMock: vi.fn(),
+ showErrorMock: vi.fn(),
+ routeState: {
+ query: {} as Record,
+ },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+ useRouter: () => ({
+ replace: replaceMock,
+ }),
+}))
+
+vi.mock('vue-i18n', () => ({
+ createI18n: () => ({
+ global: {
+ t: (key: string) => key,
+ },
+ }),
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'auth.oidc.callbackTitle') {
+ return `Signing you in with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oidc.callbackProcessing') {
+ return `Completing login with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oidc.invitationRequired') {
+ return `${params?.providerName ?? ''} invitation required`.trim()
+ }
+ if (key === 'auth.oidc.completeRegistration') {
+ return 'Complete registration'
+ }
+ if (key === 'auth.oidc.completing') {
+ return 'Completing'
+ }
+ if (key === 'auth.oidc.backToLogin') {
+ return 'Back to login'
+ }
+ if (key === 'auth.invitationCodePlaceholder') {
+ return 'Invitation code'
+ }
+ if (key === 'auth.loginSuccess') {
+ return 'Login success'
+ }
+ if (key === 'auth.loginFailed') {
+ return 'Login failed'
+ }
+ if (key === 'auth.oidc.callbackHint') {
+ return 'Callback hint'
+ }
+ if (key === 'auth.oidc.callbackMissingToken') {
+ return 'Missing login token'
+ }
+ if (key === 'auth.oidc.completeRegistrationFailed') {
+ return 'Complete registration failed'
+ }
+ return key
+ },
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAuthStore: () => ({
+ setToken: setTokenMock,
+ }),
+ useAppStore: () => ({
+ showSuccess: showSuccessMock,
+ showError: showErrorMock,
+ }),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post: postMock,
+ },
+}))
+
+describe('WechatCallbackView', () => {
+ beforeEach(() => {
+ postMock.mockReset()
+ replaceMock.mockReset()
+ setTokenMock.mockReset()
+ showSuccessMock.mockReset()
+ showErrorMock.mockReset()
+ routeState.query = {}
+ localStorage.clear()
+ })
+
+ it('does not send adoption decisions during the initial exchange', async () => {
+ postMock.mockResolvedValueOnce({
+ data: {
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true,
+ },
+ })
+ setTokenMock.mockResolvedValue({})
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {})
+ expect(postMock).toHaveBeenCalledTimes(1)
+ })
+
+ it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
+ postMock
+ .mockResolvedValueOnce({
+ data: {
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ },
+ })
+ .mockResolvedValueOnce({
+ data: {
+ access_token: 'wechat-access-token',
+ refresh_token: 'wechat-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer',
+ redirect: '/dashboard',
+ },
+ })
+ setTokenMock.mockResolvedValue({})
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('WeChat Nick')
+ expect(setTokenMock).not.toHaveBeenCalled()
+ expect(replaceMock).not.toHaveBeenCalled()
+
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+ await checkboxes[1].setValue(false)
+
+ const buttons = wrapper.findAll('button')
+ expect(buttons).toHaveLength(1)
+ await buttons[0].trigger('click')
+ await flushPromises()
+
+ expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {})
+ expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false,
+ })
+ expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token')
+ expect(replaceMock).toHaveBeenCalledWith('/dashboard')
+ expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ postMock
+ .mockResolvedValueOnce({
+ data: {
+ error: 'invitation_required',
+ redirect: '/subscriptions',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ },
+ })
+ .mockResolvedValueOnce({
+ data: {
+ access_token: 'wechat-invite-token',
+ refresh_token: 'wechat-invite-refresh',
+ expires_in: 600,
+ token_type: 'Bearer',
+ },
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('WeChat Nick')
+ const checkboxes = wrapper.findAll('input[type="checkbox"]')
+ expect(checkboxes).toHaveLength(2)
+ await checkboxes[0].setValue(false)
+ await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ')
+ await wrapper.get('button').trigger('click')
+ await flushPromises()
+
+ expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'INVITE-CODE',
+ adopt_display_name: false,
+ adopt_avatar: true,
+ })
+ expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token')
+ expect(replaceMock).toHaveBeenCalledWith('/subscriptions')
+ })
+})
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index e91df5da..838f3000 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -23,20 +23,7 @@
:order-type="paymentState.orderType"
@done="onPaymentDone"
@success="onPaymentSuccess"
- />
-
-
-
@@ -265,7 +252,7 @@
diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue
index b6f6022d..e82ae229 100644
--- a/frontend/src/components/user/profile/ProfileInfoCard.vue
+++ b/frontend/src/components/user/profile/ProfileInfoCard.vue
@@ -4,11 +4,16 @@
class="border-b border-gray-100 bg-gradient-to-r from-primary-500/10 to-primary-600/5 px-6 py-5 dark:border-dark-700 dark:from-primary-500/20 dark:to-primary-600/10"
>
-
- {{ user?.email?.charAt(0).toUpperCase() || 'U' }}
+
+
{{ avatarInitial }}
@@ -41,18 +46,163 @@
{{ user.username }}
+
+
+
+
+ {{ hint.text }}
+
+
+
+
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
new file mode 100644
index 00000000..1c9531e3
--- /dev/null
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -0,0 +1,120 @@
+import { mount } from '@vue/test-utils'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
+import type { User } from '@/types'
+
+const routeState = vi.hoisted(() => ({
+ fullPath: '/profile',
+}))
+
+const locationState = vi.hoisted(() => ({
+ current: { href: 'http://localhost/profile' } as { href: string },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+}))
+
+vi.mock('vue-i18n', async (importOriginal) => {
+ const actual = await importOriginal
()
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'profile.authBindings.title') return 'Connected sign-in methods'
+ if (key === 'profile.authBindings.description') return 'Manage bound providers'
+ if (key === 'profile.authBindings.status.bound') return 'Bound'
+ if (key === 'profile.authBindings.status.notBound') return 'Not bound'
+ if (key === 'profile.authBindings.providers.email') return 'Email'
+ if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo'
+ if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
+ if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
+ if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim()
+ return key
+ },
+ }),
+ }
+})
+
+function createUser(overrides: Partial = {}): User {
+ return {
+ id: 7,
+ username: 'alice',
+ email: 'alice@example.com',
+ role: 'user',
+ balance: 10,
+ concurrency: 2,
+ status: 'active',
+ allowed_groups: null,
+ balance_notify_enabled: true,
+ balance_notify_threshold: null,
+ balance_notify_extra_emails: [],
+ created_at: '2026-04-20T00:00:00Z',
+ updated_at: '2026-04-20T00:00:00Z',
+ ...overrides,
+ }
+}
+
+describe('ProfileIdentityBindingsSection', () => {
+ beforeEach(() => {
+ routeState.fullPath = '/profile'
+ locationState.current = { href: 'http://localhost/profile' }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0',
+ })
+ })
+
+ afterEach(() => {
+ vi.unstubAllGlobals()
+ })
+
+ it('renders provider binding states and provider-specific bind actions', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ props: {
+ user: createUser({
+ auth_bindings: {
+ email: { bound: true },
+ linuxdo: { bound: true },
+ oidc: { bound: false },
+ wechat: false,
+ },
+ }),
+ linuxdoEnabled: true,
+ oidcEnabled: true,
+ oidcProviderName: 'ExampleID',
+ wechatEnabled: true,
+ },
+ })
+
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
+ expect(wrapper.get('[data-testid="profile-binding-linuxdo-status"]').text()).toBe('Bound')
+ expect(wrapper.get('[data-testid="profile-binding-oidc-status"]').text()).toBe('Not bound')
+ expect(wrapper.get('[data-testid="profile-binding-oidc-action"]').text()).toBe(
+ 'Bind ExampleID'
+ )
+ expect(wrapper.get('[data-testid="profile-binding-wechat-action"]').text()).toBe('Bind WeChat')
+ })
+
+ it('starts the WeChat bind flow for the current profile page', async () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ props: {
+ user: createUser(),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: true,
+ },
+ })
+
+ await wrapper.get('[data-testid="profile-binding-wechat-action"]').trigger('click')
+
+ expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?')
+ expect(locationState.current.href).toContain('mode=open')
+ expect(locationState.current.href).toContain('intent=bind_current_user')
+ expect(locationState.current.href).toContain('redirect=%2Fprofile')
+ })
+})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 684c196f..7d058a74 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -940,6 +940,26 @@ export default {
maxEmailsReached: 'Maximum number of notification emails reached',
unverified: 'Unverified',
verified: 'Verified',
+ },
+ authBindings: {
+ title: 'Connected Sign-In Methods',
+ description: 'View current bindings and connect another provider to this account.',
+ bindAction: 'Bind {providerName}',
+ bindSuccess: 'Account linked successfully',
+ status: {
+ bound: 'Bound',
+ notBound: 'Not bound',
+ },
+ providers: {
+ email: 'Email',
+ linuxdo: 'LinuxDo',
+ oidc: '{providerName}',
+ wechat: 'WeChat',
+ },
+ source: {
+ avatar: 'Avatar is currently synced from {providerName}',
+ username: 'Nickname is currently synced from {providerName}',
+ },
}
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 2a4c69a5..6dd74334 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -944,6 +944,26 @@ export default {
maxEmailsReached: '已达到通知邮箱数量上限',
unverified: '未验证',
verified: '已验证',
+ },
+ authBindings: {
+ title: '登录方式绑定',
+ description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
+ bindAction: '绑定 {providerName}',
+ bindSuccess: '账号绑定成功',
+ status: {
+ bound: '已绑定',
+ notBound: '未绑定',
+ },
+ providers: {
+ email: '邮箱',
+ linuxdo: 'LinuxDo',
+ oidc: '{providerName}',
+ wechat: '微信',
+ },
+ source: {
+ avatar: '头像当前来自 {providerName}',
+ username: '昵称当前来自 {providerName}',
+ },
}
},
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 9c9722a9..a19d6c26 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -34,10 +34,47 @@ export interface NotifyEmailEntry {
// ==================== User & Auth Types ====================
+export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat'
+
+export interface UserAuthBindingStatus {
+ bound?: boolean
+ provider?: UserAuthProvider | string
+ provider_key?: string | null
+ provider_subject?: string | null
+ issuer?: string | null
+ label?: string | null
+ provider_label?: string | null
+ metadata?: Record<string, unknown>
+}
+
+export interface UserProfileSourceContext {
+ provider?: UserAuthProvider | string
+ source?: string | null
+ label?: string | null
+ provider_label?: string | null
+}
+
export interface User {
id: number
username: string
email: string
+ avatar_url?: string | null
+ avatar_source?: string | UserProfileSourceContext | null
+ username_source?: string | UserProfileSourceContext | null
+ display_name_source?: string | UserProfileSourceContext | null
+ nickname_source?: string | UserProfileSourceContext | null
+ profile_sources?: {
+ avatar?: string | UserProfileSourceContext | null
+ username?: string | UserProfileSourceContext | null
+ display_name?: string | UserProfileSourceContext | null
+ nickname?: string | UserProfileSourceContext | null
+ }
+ auth_bindings?: Partial<Record<UserAuthProvider, UserAuthBindingStatus | boolean>>
+ identity_bindings?: Partial<Record<UserAuthProvider, UserAuthBindingStatus | boolean>>
+ email_bound?: boolean
+ linuxdo_bound?: boolean
+ oidc_bound?: boolean
+ wechat_bound?: boolean
role: 'admin' | 'user' // User role for authorization
balance: number // User balance for API usage
concurrency: number // Allowed concurrent requests
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 0a125def..6dc8f242 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -136,6 +136,9 @@ import { useAuthStore, useAppStore } from '@/stores'
import {
completeLinuxDoOAuthRegistration,
exchangePendingOAuthCompletion,
+ getOAuthCompletionKind,
+ isOAuthLoginCompletion,
+ persistOAuthTokenContext,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -162,6 +165,7 @@ const suggestedAvatarUrl = ref('')
const adoptDisplayName = ref(true)
const adoptAvatar = ref(true)
const needsAdoptionConfirmation = ref(false)
+const bindSuccessMessage = t('profile.authBindings.bindSuccess')
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -209,18 +213,19 @@ function hasSuggestedProfile(completion: {
return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
}
-async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
- if (!completion.access_token) {
- throw new Error(t('auth.linuxdo.callbackMissingToken'))
+async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (getOAuthCompletionKind(completion) === 'bind') {
+ const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ appStore.showSuccess(bindSuccessMessage)
+ await router.replace(bindRedirect)
+ return
}
- if (completion.refresh_token) {
- localStorage.setItem('refresh_token', completion.refresh_token)
- }
- if (completion.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ if (!isOAuthLoginCompletion(completion)) {
+ throw new Error(t('auth.linuxdo.callbackMissingToken'))
}
+ persistOAuthTokenContext(completion)
await authStore.setToken(completion.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect)
@@ -236,12 +241,7 @@ async function handleSubmitInvitation() {
invitationCode.value.trim(),
currentAdoptionDecision()
)
- if (tokenData.refresh_token) {
- localStorage.setItem('refresh_token', tokenData.refresh_token)
- }
- if (tokenData.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
- }
+ persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
@@ -258,7 +258,7 @@ async function handleContinueLogin() {
isSubmitting.value = true
try {
const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
- await finalizeLogin(completion, redirectTo.value)
+ await finalizeCompletion(completion, redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
@@ -305,7 +305,7 @@ onMounted(async () => {
return
}
- await finalizeLogin(completion, redirect)
+ await finalizeCompletion(completion, redirect)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 55f8af6e..83304226 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -145,7 +145,10 @@ import { useAuthStore, useAppStore } from '@/stores'
import {
completeOIDCOAuthRegistration,
exchangePendingOAuthCompletion,
+ getOAuthCompletionKind,
getPublicSettings,
+ isOAuthLoginCompletion,
+ persistOAuthTokenContext,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -172,6 +175,7 @@ const suggestedAvatarUrl = ref('')
const adoptDisplayName = ref(true)
const adoptAvatar = ref(true)
const needsAdoptionConfirmation = ref(false)
+const bindSuccessMessage = t('profile.authBindings.bindSuccess')
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
@@ -231,18 +235,19 @@ function hasSuggestedProfile(completion: {
return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
}
-async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
- if (!completion.access_token) {
- throw new Error(t('auth.oidc.callbackMissingToken'))
+async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (getOAuthCompletionKind(completion) === 'bind') {
+ const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ appStore.showSuccess(bindSuccessMessage)
+ await router.replace(bindRedirect)
+ return
}
- if (completion.refresh_token) {
- localStorage.setItem('refresh_token', completion.refresh_token)
- }
- if (completion.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ if (!isOAuthLoginCompletion(completion)) {
+ throw new Error(t('auth.oidc.callbackMissingToken'))
}
+ persistOAuthTokenContext(completion)
await authStore.setToken(completion.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect)
@@ -258,12 +263,7 @@ async function handleSubmitInvitation() {
invitationCode.value.trim(),
currentAdoptionDecision()
)
- if (tokenData.refresh_token) {
- localStorage.setItem('refresh_token', tokenData.refresh_token)
- }
- if (tokenData.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
- }
+ persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
@@ -280,7 +280,7 @@ async function handleContinueLogin() {
isSubmitting.value = true
try {
const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
- await finalizeLogin(completion, redirectTo.value)
+ await finalizeCompletion(completion, redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
@@ -329,7 +329,7 @@ onMounted(async () => {
return
}
- await finalizeLogin(completion, redirect)
+ await finalizeCompletion(completion, redirect)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 407b395b..ac0ede4c 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -140,27 +140,16 @@ import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue'
-import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
-
-interface OAuthTokenResponse {
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
-}
-
-interface PendingOAuthExchangeResponse {
- access_token?: string
- refresh_token?: string
- expires_in?: number
- token_type?: string
- redirect?: string
- error?: string
- adoption_required?: boolean
- suggested_display_name?: string
- suggested_avatar_url?: string
-}
+import {
+ completeWeChatOAuthRegistration,
+ exchangePendingOAuthCompletion,
+ getOAuthCompletionKind,
+ isOAuthLoginCompletion,
+ persistOAuthTokenContext,
+ type OAuthAdoptionDecision,
+ type PendingOAuthExchangeResponse
+} from '@/api/auth'
const route = useRoute()
const router = useRouter()
@@ -182,6 +171,7 @@ const suggestedAvatarUrl = ref('')
const adoptDisplayName = ref(true)
const adoptAvatar = ref(true)
const needsAdoptionConfirmation = ref(false)
+const bindSuccessMessage = t('profile.authBindings.bindSuccess')
const providerName = 'WeChat'
@@ -200,10 +190,10 @@ function sanitizeRedirectPath(path: string | null | undefined): string {
return path
}
-function currentAdoptionDecision(): Record {
+function currentAdoptionDecision(): OAuthAdoptionDecision {
return {
- adopt_display_name: adoptDisplayName.value,
- adopt_avatar: adoptAvatar.value,
+ adoptDisplayName: adoptDisplayName.value,
+ adoptAvatar: adoptAvatar.value
}
}
@@ -224,49 +214,35 @@ function hasSuggestedProfile(completion: PendingOAuthExchangeResponse): boolean
return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
}
-async function exchangePendingOAuthCompletion(): Promise {
- const { data } = await apiClient.post('/auth/oauth/pending/exchange', {})
- return data
-}
-
-async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) {
- if (!completion.access_token) {
- throw new Error(t('auth.oidc.callbackMissingToken'))
+async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
+ if (getOAuthCompletionKind(completion) === 'bind') {
+ const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ appStore.showSuccess(bindSuccessMessage)
+ await router.replace(bindRedirect)
+ return
}
- if (completion.refresh_token) {
- localStorage.setItem('refresh_token', completion.refresh_token)
- }
- if (completion.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000))
+ if (!isOAuthLoginCompletion(completion)) {
+ throw new Error(t('auth.oidc.callbackMissingToken'))
}
+ persistOAuthTokenContext(completion)
await authStore.setToken(completion.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirect)
}
-async function completeWeChatOAuthRegistration(invitation: string): Promise {
- const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', {
- invitation_code: invitation,
- ...currentAdoptionDecision(),
- })
- return data
-}
-
async function handleSubmitInvitation() {
invitationError.value = ''
if (!invitationCode.value.trim()) return
isSubmitting.value = true
try {
- const tokenData = await completeWeChatOAuthRegistration(invitationCode.value.trim())
- if (tokenData.refresh_token) {
- localStorage.setItem('refresh_token', tokenData.refresh_token)
- }
- if (tokenData.expires_in) {
- localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000))
- }
+ const tokenData = await completeWeChatOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
+ persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
await router.replace(redirectTo.value)
@@ -282,11 +258,8 @@ async function handleSubmitInvitation() {
async function handleContinueLogin() {
isSubmitting.value = true
try {
- const { data } = await apiClient.post(
- '/auth/oauth/pending/exchange',
- currentAdoptionDecision()
- )
- await finalizeLogin(data, redirectTo.value)
+ const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision())
+ await finalizeCompletion(completion, redirectTo.value)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
@@ -333,7 +306,7 @@ onMounted(async () => {
return
}
- await finalizeLogin(completion, redirect)
+ await finalizeCompletion(completion, redirect)
} catch (e: unknown) {
const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } }
errorMessage.value =
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index 60a40474..7ffdcd19 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -39,10 +39,14 @@ vi.mock('@/stores', () => ({
})
}))
-vi.mock('@/api/auth', () => ({
- exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
- completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args)
-}))
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args)
+ }
+})
describe('LinuxDoCallbackView', () => {
beforeEach(() => {
@@ -132,6 +136,64 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/dashboard')
})
+ it('treats a completion without token as bind success and returns to profile', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({})
+
+ mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(setToken).not.toHaveBeenCalled()
+ expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replace).toHaveBeenCalledWith('/profile')
+ })
+
+ it('supports bind completion after adoption confirmation', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'LinuxDo Nick',
+ suggested_avatar_url: 'https://cdn.example/linuxdo.png'
+ })
+ .mockResolvedValueOnce({
+ redirect: '/profile/security'
+ })
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.findAll('button')[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+ expect(setToken).not.toHaveBeenCalled()
+ expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replace).toHaveBeenCalledWith('/profile/security')
+ })
+
it('renders adoption choices for invitation flow and submits the selected values', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'invitation_required',
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index 299c0746..f8de79f2 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -45,11 +45,15 @@ vi.mock('@/stores', () => ({
})
}))
-vi.mock('@/api/auth', () => ({
- exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
- completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
- getPublicSettings: (...args: any[]) => getPublicSettings(...args)
-}))
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
+ completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettings(...args)
+ }
+})
describe('OidcCallbackView', () => {
beforeEach(() => {
@@ -143,6 +147,43 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/dashboard')
})
+ it('supports bind completion after adoption confirmation', async () => {
+ exchangePendingOAuthCompletion
+ .mockResolvedValueOnce({
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'OIDC Nick',
+ suggested_avatar_url: 'https://cdn.example/oidc.png'
+ })
+ .mockResolvedValueOnce({
+ redirect: '/profile'
+ })
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.findAll('button')[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+ expect(setToken).not.toHaveBeenCalled()
+ expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replace).toHaveBeenCalledWith('/profile')
+ })
+
it('renders adoption choices for invitation flow and submits the selected values', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'invitation_required',
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index a9e2ada2..896bf15d 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -3,14 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import WechatCallbackView from '@/views/auth/WechatCallbackView.vue'
const {
- postMock,
+ exchangePendingOAuthCompletionMock,
+ completeWeChatOAuthRegistrationMock,
replaceMock,
setTokenMock,
showSuccessMock,
showErrorMock,
routeState,
} = vi.hoisted(() => ({
- postMock: vi.fn(),
+ exchangePendingOAuthCompletionMock: vi.fn(),
+ completeWeChatOAuthRegistrationMock: vi.fn(),
replaceMock: vi.fn(),
setTokenMock: vi.fn(),
showSuccessMock: vi.fn(),
@@ -86,15 +88,19 @@ vi.mock('@/stores', () => ({
}),
}))
-vi.mock('@/api/client', () => ({
- apiClient: {
- post: postMock,
- },
-}))
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args),
+ completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
+ }
+})
describe('WechatCallbackView', () => {
beforeEach(() => {
- postMock.mockReset()
+ exchangePendingOAuthCompletionMock.mockReset()
+ completeWeChatOAuthRegistrationMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
showSuccessMock.mockReset()
@@ -104,14 +110,12 @@ describe('WechatCallbackView', () => {
})
it('does not send adoption decisions during the initial exchange', async () => {
- postMock.mockResolvedValueOnce({
- data: {
- access_token: 'access-token',
- refresh_token: 'refresh-token',
- expires_in: 3600,
- redirect: '/dashboard',
- adoption_required: true,
- },
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ access_token: 'access-token',
+ refresh_token: 'refresh-token',
+ expires_in: 3600,
+ redirect: '/dashboard',
+ adoption_required: true,
})
setTokenMock.mockResolvedValue({})
@@ -128,28 +132,24 @@ describe('WechatCallbackView', () => {
await flushPromises()
- expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {})
- expect(postMock).toHaveBeenCalledTimes(1)
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith()
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1)
})
it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => {
- postMock
+ exchangePendingOAuthCompletionMock
.mockResolvedValueOnce({
- data: {
- redirect: '/dashboard',
- adoption_required: true,
- suggested_display_name: 'WeChat Nick',
- suggested_avatar_url: 'https://cdn.example/wechat.png',
- },
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
})
.mockResolvedValueOnce({
- data: {
- access_token: 'wechat-access-token',
- refresh_token: 'wechat-refresh-token',
- expires_in: 3600,
- token_type: 'Bearer',
- redirect: '/dashboard',
- },
+ access_token: 'wechat-access-token',
+ refresh_token: 'wechat-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer',
+ redirect: '/dashboard',
})
setTokenMock.mockResolvedValue({})
@@ -179,34 +179,26 @@ describe('WechatCallbackView', () => {
await buttons[0].trigger('click')
await flushPromises()
- expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {})
- expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', {
- adopt_display_name: true,
- adopt_avatar: false,
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1)
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: false,
})
expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token')
expect(replaceMock).toHaveBeenCalledWith('/dashboard')
expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token')
})
- it('renders adoption choices for invitation flow and submits the selected values', async () => {
- postMock
+ it('supports bind completion after adoption confirmation', async () => {
+ exchangePendingOAuthCompletionMock
.mockResolvedValueOnce({
- data: {
- error: 'invitation_required',
- redirect: '/subscriptions',
- adoption_required: true,
- suggested_display_name: 'WeChat Nick',
- suggested_avatar_url: 'https://cdn.example/wechat.png',
- },
+ redirect: '/dashboard',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
})
.mockResolvedValueOnce({
- data: {
- access_token: 'wechat-invite-token',
- refresh_token: 'wechat-invite-refresh',
- expires_in: 600,
- token_type: 'Bearer',
- },
+ redirect: '/profile/connections',
})
const wrapper = mount(WechatCallbackView, {
@@ -222,6 +214,46 @@ describe('WechatCallbackView', () => {
await flushPromises()
+ await wrapper.findAll('button')[0].trigger('click')
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, {
+ adoptDisplayName: true,
+ adoptAvatar: true,
+ })
+ expect(setTokenMock).not.toHaveBeenCalled()
+ expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
+ expect(replaceMock).toHaveBeenCalledWith('/profile/connections')
+ })
+
+ it('renders adoption choices for invitation flow and submits the selected values', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'invitation_required',
+ redirect: '/subscriptions',
+ adoption_required: true,
+ suggested_display_name: 'WeChat Nick',
+ suggested_avatar_url: 'https://cdn.example/wechat.png',
+ })
+ completeWeChatOAuthRegistrationMock.mockResolvedValue({
+ access_token: 'wechat-invite-token',
+ refresh_token: 'wechat-invite-refresh',
+ expires_in: 600,
+ token_type: 'Bearer',
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
expect(wrapper.text()).toContain('WeChat Nick')
const checkboxes = wrapper.findAll('input[type="checkbox"]')
expect(checkboxes).toHaveLength(2)
@@ -230,10 +262,9 @@ describe('WechatCallbackView', () => {
await wrapper.get('button').trigger('click')
await flushPromises()
- expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', {
- invitation_code: 'INVITE-CODE',
- adopt_display_name: false,
- adopt_avatar: true,
+ expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', {
+ adoptDisplayName: false,
+ adoptAvatar: true,
})
expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token')
expect(replaceMock).toHaveBeenCalledWith('/subscriptions')
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index e7418ebb..f7418be9 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -2,18 +2,53 @@
-
-
-
+
+
+
-
-
+
+
+
+
-
-
{{ t('common.contactSupport') }} {{ contactInfo }}
+
+
+
+
+
+ {{ t('common.contactSupport') }}
+
+
{{ contactInfo }}
+
+
+
+
@@ -29,26 +65,78 @@
\ No newline at end of file
+const formatCurrency = (value: number) => `$${value.toFixed(2)}`
+
--
GitLab
From 4e0e69154649def4f4149054e09ff03fc8a3e50a Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 18:39:53 +0800
Subject: [PATCH 036/261] feat: apply auth source signup defaults
---
.../service/admin_service_apikey_test.go | 3 +
.../service/admin_service_delete_test.go | 51 ++++--
backend/internal/service/auth_service.go | 117 +++++++++----
.../service/auth_service_register_test.go | 159 +++++++++++++++++-
4 files changed, 283 insertions(+), 47 deletions(-)
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index b802a9c2..487fb5f1 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -79,6 +79,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index 323286b0..ac1d8ee7 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -13,15 +13,18 @@ import (
)
type userRepoStub struct {
- user *User
- getErr error
- createErr error
- deleteErr error
- exists bool
- existsErr error
- nextID int64
- created []*User
- deletedIDs []int64
+ user *User
+ getErr error
+ createErr error
+ deleteErr error
+ exists bool
+ existsErr error
+ nextID int64
+ created []*User
+ updated []*User
+ deletedIDs []int64
+ usersByEmail map[string]*User
+ getByEmailErr error
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID
}
s.created = append(s.created, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
return nil
}
@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
- panic("unexpected GetByEmail call")
+ if s.getByEmailErr != nil {
+ return nil, s.getByEmailErr
+ }
+ if s.usersByEmail != nil {
+ if user, ok := s.usersByEmail[email]; ok {
+ return user, nil
+ }
+ }
+ if s.user != nil && s.user.Email == email {
+ return s.user, nil
+ }
+ return nil, ErrUserNotFound
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
- panic("unexpected Update call")
+ s.updated = append(s.updated, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
+ return nil
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
@@ -113,6 +138,10 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 962009ce..40753139 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -78,6 +78,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
+type signupGrantPlan struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+}
+
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
@@ -187,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 获取默认配置
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 创建用户
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -214,7 +214,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, "email", true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -479,21 +479,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 新用户默认值。
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -511,8 +506,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -596,20 +591,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
- if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
- }
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
@@ -642,8 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
- s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -659,8 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -694,22 +685,78 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+ if s.settingService == nil {
+ return
+ }
+ s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting")
+}
+
+func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
- items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
+ Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
+func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
+ plan := signupGrantPlan{}
+ if s != nil && s.cfg != nil {
+ plan.Balance = s.cfg.Default.UserBalance
+ plan.Concurrency = s.cfg.Default.UserConcurrency
+ }
+ if s == nil || s.settingService == nil {
+ return plan
+ }
+
+ plan.Balance = s.settingService.GetDefaultBalance(ctx)
+ plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
+ plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
+
+ defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
+ return plan
+ }
+
+ providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
+ if !ok || !providerDefaults.GrantOnSignup {
+ return plan
+ }
+
+ plan.Balance = providerDefaults.Balance
+ plan.Concurrency = providerDefaults.Concurrency
+ plan.Subscriptions = providerDefaults.Subscriptions
+ return plan
+}
+
+func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
+ if defaults == nil {
+ return ProviderDefaultGrantSettings{}, false
+ }
+
+ switch strings.ToLower(strings.TrimSpace(signupSource)) {
+ case "email":
+ return defaults.Email, true
+ case "linuxdo":
+ return defaults.LinuxDo, true
+ case "oidc":
+ return defaults.OIDC, true
+ case "wechat":
+ return defaults.WeChat, true
+ default:
+ return ProviderDefaultGrantSettings{}, false
+ }
+}
+
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 103bafe7..901b3db3 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+ if s.err != nil {
+ return nil, s.err
+ }
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ result[key] = v
+ }
+ }
+ return result, nil
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error
}
+type refreshTokenCacheStub struct{}
+
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
+ return nil, ErrRefreshTokenNotFound
+}
+
+func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -484,3 +535,109 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
+
+func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
+ repo := &userRepoStub{nextID: 52}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 12.5, user.Balance)
+ require.Equal(t, 7, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 53}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "99",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "88",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-global@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 3.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
+ repo := &userRepoStub{nextID: 61}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Equal(t, int64(61), user.ID)
+ require.Equal(t, 21.75, user.Balance)
+ require.Equal(t, 9, user.Concurrency)
+ require.Len(t, repo.created, 1)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(22), assigner.calls[0].GroupID)
+ require.Equal(t, 14, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
+ existing := &User{
+ ID: 88,
+ Email: "linuxdo-123@linuxdo-connect.invalid",
+ Username: "existing-linuxdo",
+ Role: RoleUser,
+ Status: StatusActive,
+ Balance: 4,
+ Concurrency: 1,
+ TokenVersion: 2,
+ }
+ repo := &userRepoStub{user: existing}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.Equal(t, existing.ID, user.ID)
+ require.Equal(t, 4.0, user.Balance)
+ require.Equal(t, 1, user.Concurrency)
+ require.Empty(t, repo.created)
+ require.Empty(t, assigner.calls)
+}
--
GitLab
From 0353c3870f938f5d57d4060232dfe9e3bc4a01ff Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 18:40:34 +0800
Subject: [PATCH 037/261] test: update user service stubs for identity
summaries
---
.../service/billing_cache_service_singleflight_test.go | 4 ++++
backend/internal/service/user_service_test.go | 9 ++++++---
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go
index 4a8b8f03..2ebc2f04 100644
--- a/backend/internal/service/billing_cache_service_singleflight_test.go
+++ b/backend/internal/service/billing_cache_service_singleflight_test.go
@@ -86,6 +86,10 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User,
return &User{ID: id, Balance: s.balance}, nil
}
+func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 7d63bb36..89f0362e 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -100,9 +100,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
-func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
--
GitLab
From d47580a144f083965b04e4612e4fd9ceb4779e35 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 18:42:28 +0800
Subject: [PATCH 038/261] test: pin email signup defaults in register tests
---
backend/internal/service/auth_service_register_test.go | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 901b3db3..e0dce982 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -373,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
@@ -520,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
--
GitLab
From 6a75bd77e3cbf11a12a624a474a979f684be5fb6 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 19:30:09 +0800
Subject: [PATCH 039/261] feat: add pending oauth email onboarding flow
---
.../internal/handler/auth_linuxdo_oauth.go | 125 +++--
.../handler/auth_oauth_pending_flow.go | 426 +++++++++++++++++-
.../handler/auth_oauth_pending_flow_test.go | 371 ++++++++++++++-
backend/internal/handler/auth_oidc_oauth.go | 137 +++---
backend/internal/handler/auth_wechat_oauth.go | 33 ++
backend/internal/handler/dto/settings.go | 1 +
backend/internal/handler/setting_handler.go | 1 +
.../handler/setting_handler_public_test.go | 83 ++++
backend/internal/server/routes/auth.go | 48 ++
.../internal/service/auth_oauth_email_flow.go | 151 +++++++
backend/internal/service/setting_service.go | 2 +
.../service/setting_service_public_test.go | 13 +
backend/internal/service/settings_view.go | 1 +
13 files changed, 1273 insertions(+), 119 deletions(-)
create mode 100644 backend/internal/handler/setting_handler_public_test.go
create mode 100644 backend/internal/service/auth_oauth_email_flow.go
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 835e5fd8..c4ecb8fa 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
+ identityKey := service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
if err != nil {
@@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: oauthIntentBindCurrentUser,
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "linuxdo",
- ProviderKey: "linuxdo",
- ProviderSubject: subject,
- },
- TargetUserID: &targetUserID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "suggested_display_name": displayName,
- "suggested_avatar_url": avatarURL,
- },
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityKey,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
@@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityKey,
+ TargetUserID: &user.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "linuxdo",
- ProviderKey: "linuxdo",
- ProviderSubject: subject,
- },
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "suggested_display_name": displayName,
- "suggested_avatar_url": avatarURL,
- },
+ Intent: "login",
+ Identity: identityKey,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"error": "invitation_required",
"redirect": redirectTo,
@@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "linuxdo",
- ProviderKey: "linuxdo",
- ProviderSubject: subject,
- },
- TargetUserID: &user.ID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "suggested_display_name": displayName,
- "suggested_avatar_url": avatarURL,
- },
+ Intent: "login",
+ Identity: identityKey,
+ TargetUserID: &user.ID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 2d6c3714..99b9b406 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct {
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
+type bindPendingOAuthLoginRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type createPendingOAuthAccountRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code,omitempty"`
+ Password string `json:"password" binding:"required,min=6"`
+ InvitationCode string `json:"invitation_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
@@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) {
return result, true
}
+func clonePendingMap(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ merged := clonePendingMap(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := merged["redirect"]; !exists {
+ merged["redirect"] = session.RedirectTo
+ }
+ }
+ for key, value := range overrides {
+ if value == nil {
+ delete(merged, key)
+ continue
+ }
+ merged[key] = value
+ }
+ applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
+ return merged
+}
+
func pendingSessionStringValue(values map[string]any, key string) string {
if len(values) == 0 {
return ""
@@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client {
return h.authService.EntClient()
}
+func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
+ if err != nil || defaults == nil {
+ return false
+ }
+ return defaults.ForceEmailOnThirdPartySignup
+}
+
+func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ record, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+
+ userEntity, err := client.User.Get(ctx, record.UserID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
+func (h *AuthHandler) createOAuthEmailRequiredPendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+) error {
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ "step": "email_required",
+ "force_email_on_signup": true,
+ "email_binding_required": true,
+ "existing_account_bindable": true,
+ },
+ })
+}
+
+func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
+func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
+func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
+func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
+
+func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "linuxdo")
+}
+
+func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
+
+func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "wechat")
+}
+
+func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "")
+}
+
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
@@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
return decision, nil
}
+func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
+ if err != nil {
+ return nil, err
+ }
+ if decision != nil {
+ return decision, nil
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ })
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func updatePendingOAuthSessionProgress(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ intent string,
+ resolvedEmail string,
+ targetUserID *int64,
+ completionResponse map[string]any,
+) (*dbent.PendingAuthSession, error) {
+ if client == nil || session == nil {
+ return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+
+ localFlowState := clonePendingMap(session.LocalFlowState)
+ localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
+
+ update := client.PendingAuthSession.UpdateOneID(session.ID).
+ SetIntent(strings.TrimSpace(intent)).
+ SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
+ SetLocalFlowState(localFlowState)
+ if targetUserID != nil && *targetUserID > 0 {
+ update = update.SetTargetUserID(*targetUserID)
+ } else {
+ update = update.ClearTargetUserID()
+ }
+ return update.Save(ctx)
+}
+
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
@@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
return decision.AdoptDisplayName || decision.AdoptAvatar
}
-func applyPendingOAuthAdoption(
+func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
+ forceBind bool,
) error {
- if client == nil || session == nil || decision == nil {
+ if client == nil || session == nil {
return nil
}
- if !shouldBindPendingOAuthIdentity(session, decision) {
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
@@ -427,11 +625,11 @@ func applyPendingOAuthAdoption(
}
adoptedDisplayName := ""
- if decision.AdoptDisplayName {
+ if decision != nil && decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
- if decision.AdoptAvatar {
+ if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
@@ -441,7 +639,7 @@ func applyPendingOAuthAdoption(
}
defer func() { _ = tx.Rollback() }()
- if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
@@ -458,10 +656,10 @@ func applyPendingOAuthAdoption(
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
- if decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
- if decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -473,7 +671,7 @@ func applyPendingOAuthAdoption(
return err
}
- if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
+ if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
@@ -484,6 +682,16 @@ func applyPendingOAuthAdoption(
return tx.Commit()
}
+func applyPendingOAuthAdoption(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+) error {
+ return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false)
+}
+
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
@@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream
}
}
+func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ return svc, session, clearCookies, nil
+}
+
+func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
+ payload := gin.H{
+ "auth_result": "pending_session",
+ "provider": strings.TrimSpace(session.ProviderType),
+ "intent": strings.TrimSpace(session.Intent),
+ }
+ for key, value := range mergePendingCompletionResponse(session, nil) {
+ payload[key] = value
+ }
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ payload["email"] = email
+ }
+ return payload
+}
+
+func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
+ var req bindPendingOAuthLoginRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
+ response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
+ if err != nil {
+ response.InternalError(c, "Failed to generate token pair")
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
+ var req createPendingOAuthAccountRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
+ return
+ }
+ if existingUser != nil {
+ completionResponse := mergePendingCompletionResponse(session, map[string]any{
+ "step": "bind_login_required",
+ "email": email,
+ })
+ session, err = updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ "adopt_existing_user_by_email",
+ email,
+ &existingUser.ID,
+ completionResponse,
+ )
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err))
+ return
+ }
+
+ if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
+
+ tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
+ c.Request.Context(),
+ email,
+ req.Password,
+ strings.TrimSpace(req.VerifyCode),
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
// POST /api/v1/auth/oauth/pending/exchange
func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 3afb4fb7..80338b8a 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis
require.Nil(t, storedSession.ConsumedAt)
}
+func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-123").
+ SetBrowserSessionKey("create-account-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Fresh OIDC User",
+ "suggested_avatar_url": "https://cdn.example/fresh.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusActive, createdUser.Status)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-create-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, createdUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-123").
+ SetBrowserSessionKey("existing-email-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, "adopt_existing_user_by_email", payload["intent"])
+ require.Equal(t, "oidc", payload["provider"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, true, payload["adoption_required"])
+ require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+ require.Nil(t, storedSession.ConsumedAt)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-existing-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+}
+
+func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-invalid-password-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-invalid-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusUnauthorized, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
+}
+
+func newOAuthPendingFlowTestHandlerWithEmailVerification(
+ t *testing.T,
+ invitationEnabled bool,
+ email string,
+ code string,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ cache := &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ email: {
+ Code: code,
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ }
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
+}
+
+func newOAuthPendingFlowTestHandlerWithOptions(
+ t *testing.T,
+ invitationEnabled bool,
+ emailVerifyEnabled bool,
+ emailCache service.EmailCache,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
@@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
},
}, cfg)
userRepo := &oauthPendingFlowUserRepo{client: client}
+ var emailService *service.EmailService
+ if emailCache != nil {
+ emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled),
+ },
+ }, emailCache)
+ }
authSvc := service.NewAuthService(
client,
userRepo,
@@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth
&oauthPendingFlowRefreshTokenCacheStub{},
cfg,
settingSvc,
- nil,
+ emailService,
nil,
nil,
nil,
@@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error
type oauthPendingFlowRefreshTokenCacheStub struct{}
+type oauthPendingFlowEmailCacheStub struct {
+ verificationCodes map[string]*service.VerificationCodeData
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
+ if s == nil || s.verificationCodes == nil {
+ return nil, nil
+ }
+ return s.verificationCodes[email], nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
+ if s.verificationCodes == nil {
+ s.verificationCodes = map[string]*service.VerificationCodeData{}
+ }
+ s.verificationCodes[email] = data
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
+ delete(s.verificationCodes, email)
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 0f79759e..909d6379 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
idClaims.Name,
oidcFallbackUsername(subject),
)
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
if err != nil {
@@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: oauthIntentBindCurrentUser,
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "oidc",
- ProviderKey: issuer,
- ProviderSubject: subject,
- },
- TargetUserID: &targetUserID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
- "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
- "suggested_avatar_url": userInfoClaims.AvatarURL,
- },
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityRef,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
@@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityRef,
+ TargetUserID: &user.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "oidc",
- ProviderKey: issuer,
- ProviderSubject: subject,
- },
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
- "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
- "suggested_avatar_url": userInfoClaims.AvatarURL,
- },
+ Intent: "login",
+ Identity: identityRef,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"error": "invitation_required",
"redirect": redirectTo,
@@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
- Intent: "login",
- Identity: service.PendingAuthIdentityKey{
- ProviderType: "oidc",
- ProviderKey: issuer,
- ProviderSubject: subject,
- },
- TargetUserID: &user.ID,
- ResolvedEmail: email,
- RedirectTo: redirectTo,
- BrowserSessionKey: browserSessionKey,
- UpstreamIdentityClaims: map[string]any{
- "email": email,
- "username": username,
- "subject": subject,
- "issuer": issuer,
- "email_verified": emailVerified != nil && *emailVerified,
- "provider_fallback": strings.TrimSpace(cfg.ProviderName),
- "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
- "suggested_avatar_url": userInfoClaims.AvatarURL,
- },
+ Intent: "login",
+ Identity: identityRef,
+ TargetUserID: &user.ID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 45ac6cad..6d37c799 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
}
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ }
normalizedIntent := normalizeWeChatOAuthIntent(intent)
if normalizedIntent == wechatOAuthIntentBind {
@@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
return
}
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
if err != nil {
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil {
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index f44b3e3b..637e317b 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index c7bc3e2a..9925f066 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go
new file mode 100644
index 00000000..114c7245
--- /dev/null
+++ b/backend/internal/handler/setting_handler_public_test.go
@@ -0,0 +1,83 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+// settingHandlerPublicRepoStub is a minimal settings-repository test double.
+// Only GetMultiple is expected to be called by GetPublicSettings; every other
+// method panics so an unexpected repository access fails the test loudly
+// instead of silently returning zero values.
+type settingHandlerPublicRepoStub struct {
+	// values backs GetMultiple lookups; missing keys are simply omitted
+	// from the returned map, mirroring a real repository.
+	values map[string]string
+}
+
+// Get is not used by the public-settings path; panic to surface misuse.
+func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+	panic("unexpected Get call")
+}
+
+// GetValue is not used by the public-settings path; panic to surface misuse.
+func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+	panic("unexpected GetValue call")
+}
+
+// Set is not used by the public-settings path; panic to surface misuse.
+func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
+	panic("unexpected Set call")
+}
+
+// GetMultiple returns the subset of requested keys present in s.values.
+// Keys absent from the stub are left out of the result rather than mapped
+// to an empty string.
+func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+	out := make(map[string]string, len(keys))
+	for _, key := range keys {
+		if value, ok := s.values[key]; ok {
+			out[key] = value
+		}
+	}
+	return out, nil
+}
+
+// SetMultiple is not used by the public-settings path; panic to surface misuse.
+func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+	panic("unexpected SetMultiple call")
+}
+
+// GetAll is not used by the public-settings path; panic to surface misuse.
+func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+	panic("unexpected GetAll call")
+}
+
+// Delete is not used by the public-settings path; panic to surface misuse.
+func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
+	panic("unexpected Delete call")
+}
+
+// TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup
+// verifies that the force_email_on_third_party_signup setting stored in the
+// repository is surfaced through the public (unauthenticated) settings
+// endpoint, end to end through SettingService and the handler's JSON envelope.
+func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// Seed only the one setting under test; GetMultiple omits the rest,
+	// exercising the service's defaulting for missing keys.
+	repo := &settingHandlerPublicRepoStub{
+		values: map[string]string{
+			service.SettingKeyForceEmailOnThirdPartySignup: "true",
+		},
+	}
+	h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+	h.GetPublicSettings(c)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+
+	// Decode only the fields this test cares about from the standard
+	// {code, data} response envelope.
+	var resp struct {
+		Code int `json:"code"`
+		Data struct {
+			ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
+		} `json:"data"`
+	}
+	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+	require.Equal(t, 0, resp.Code)
+	require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
+}
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index 7a34834d..1f28e9c3 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -72,18 +72,54 @@ func RegisterAuthRoutes(
}),
h.Auth.ExchangePendingOAuthCompletion,
)
+ auth.POST("/oauth/pending/create-account",
+ rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreatePendingOAuthAccount,
+ )
+ auth.POST("/oauth/pending/bind-login",
+ rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindPendingOAuthLogin,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/linuxdo/bind-login",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindLinuxDoOAuthLogin,
+ )
+ auth.POST("/oauth/linuxdo/create-account",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateLinuxDoOAuthAccount,
+ )
auth.POST("/oauth/wechat/complete-registration",
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteWeChatOAuthRegistration,
)
+ auth.POST("/oauth/wechat/bind-login",
+ rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindWeChatOAuthLogin,
+ )
+ auth.POST("/oauth/wechat/create-account",
+ rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateWeChatOAuthAccount,
+ )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
@@ -92,6 +128,18 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
+ auth.POST("/oauth/oidc/bind-login",
+ rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindOIDCOAuthLogin,
+ )
+ auth.POST("/oauth/oidc/create-account",
+ rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateOIDCOAuthAccount,
+ )
}
// 公开设置(无需认证)
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
new file mode 100644
index 00000000..ca3403d4
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -0,0 +1,151 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+)
+
+// VerifyOAuthEmailCode verifies the locally entered email verification code for
+// third-party signup and binding flows. This is intentionally independent from
+// the global registration email verification toggle.
+//
+// The email is lowercased and whitespace-trimmed before the lookup so the code
+// matches regardless of how the address was typed. An empty (or whitespace-only)
+// email or code yields ErrEmailVerifyRequired; a nil service or missing email
+// service yields ErrServiceUnavailable. Otherwise the underlying
+// emailService.VerifyCode result is returned unchanged.
+func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
+	// Normalize first so whitespace-only inputs are rejected by the empty checks.
+	email = strings.TrimSpace(strings.ToLower(email))
+	verifyCode = strings.TrimSpace(verifyCode)
+
+	if email == "" {
+		return ErrEmailVerifyRequired
+	}
+	if verifyCode == "" {
+		return ErrEmailVerifyRequired
+	}
+	// Defensive nil guard: callers may hold a zero/partially wired AuthService
+	// (e.g. in tests), in which code verification is simply unavailable.
+	if s == nil || s.emailService == nil {
+		return ErrServiceUnavailable
+	}
+	return s.emailService.VerifyCode(ctx, email, verifyCode)
+}
+
+// RegisterOAuthEmailAccount creates a local account from a third-party first
+// login after the user has verified a local email address.
+//
+// The flow is: registration-enabled gate → email normalization and policy
+// checks → local verification-code check (VerifyOAuthEmailCode) → optional
+// invitation-code validation → duplicate-email check → user creation →
+// post-signup bootstrap, subscription grants and invitation consumption →
+// token-pair issuance.
+//
+// Parameters: email and signupSource are normalized (trimmed, lowercased)
+// before use; signupSource defaults to "email" when blank. Returns the issued
+// token pair and the created user, or a sentinel error (ErrRegDisabled,
+// ErrEmailReserved, ErrEmailVerifyRequired, ErrInvitationCodeRequired/Invalid,
+// ErrEmailExists, ErrServiceUnavailable) describing which gate failed.
+func (s *AuthService) RegisterOAuthEmailAccount(
+	ctx context.Context,
+	email string,
+	password string,
+	verifyCode string,
+	invitationCode string,
+	signupSource string,
+) (*TokenPair, *User, error) {
+	if s == nil {
+		return nil, nil, ErrServiceUnavailable
+	}
+	// Third-party-initiated signups still honor the global registration toggle.
+	if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+		return nil, nil, ErrRegDisabled
+	}
+
+	email = strings.TrimSpace(strings.ToLower(email))
+	if isReservedEmail(email) {
+		return nil, nil, ErrEmailReserved
+	}
+	if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+		return nil, nil, err
+	}
+	if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
+		return nil, nil, err
+	}
+
+	// Validate the invitation code up front but consume it only after the
+	// user row exists, so an invalid code never burns a DB write.
+	var invitationRedeemCode *RedeemCode
+	if s.settingService.IsInvitationCodeEnabled(ctx) {
+		if invitationCode == "" {
+			return nil, nil, ErrInvitationCodeRequired
+		}
+		redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+		if err != nil {
+			return nil, nil, ErrInvitationCodeInvalid
+		}
+		if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+			return nil, nil, ErrInvitationCodeInvalid
+		}
+		invitationRedeemCode = redeemCode
+	}
+
+	// Pre-check for duplicates; Create below re-checks, so a concurrent
+	// registration still surfaces as ErrEmailExists.
+	existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
+	if err != nil {
+		return nil, nil, ErrServiceUnavailable
+	}
+	if existsEmail {
+		return nil, nil, ErrEmailExists
+	}
+
+	hashedPassword, err := s.HashPassword(password)
+	if err != nil {
+		return nil, nil, fmt.Errorf("hash password: %w", err)
+	}
+
+	signupSource = strings.TrimSpace(strings.ToLower(signupSource))
+	if signupSource == "" {
+		signupSource = "email"
+	}
+	// The grant plan (initial balance/concurrency/subscriptions) is resolved
+	// per signup source so OAuth-originated accounts can differ from plain email.
+	grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+
+	user := &User{
+		Email:        email,
+		PasswordHash: hashedPassword,
+		Role:         RoleUser,
+		Balance:      grantPlan.Balance,
+		Concurrency:  grantPlan.Concurrency,
+		Status:       StatusActive,
+	}
+
+	if err := s.userRepo.Create(ctx, user); err != nil {
+		if errors.Is(err, ErrEmailExists) {
+			return nil, nil, ErrEmailExists
+		}
+		return nil, nil, ErrServiceUnavailable
+	}
+
+	s.postAuthUserBootstrap(ctx, user, signupSource, true)
+	s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+
+	// NOTE(review): the invitation code is consumed after user creation; if
+	// Use fails here the user already exists but the error aborts the flow.
+	// Presumably acceptable (or handled upstream) — confirm intended behavior.
+	if invitationRedeemCode != nil {
+		if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+			return nil, nil, ErrInvitationCodeInvalid
+		}
+	}
+
+	tokenPair, err := s.GenerateTokenPair(ctx, user, "")
+	if err != nil {
+		return nil, nil, fmt.Errorf("generate token pair: %w", err)
+	}
+	return tokenPair, user, nil
+}
+
+// ValidatePasswordCredentials checks the local password without completing the
+// login flow. This is used by pending third-party account adoption flows before
+// the external identity has been bound.
+//
+// The email is trimmed and lowercased before lookup. A missing user is
+// deliberately reported as ErrInvalidCredentials (same as a wrong password) so
+// the response does not reveal whether the address is registered; other
+// repository failures map to ErrServiceUnavailable, and an inactive account
+// returns ErrUserNotActive. On success the matched user is returned with no
+// session or last-login side effects.
+func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
+	if s == nil {
+		return nil, ErrServiceUnavailable
+	}
+
+	user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
+	if err != nil {
+		// Collapse "user not found" into the generic credentials error to
+		// avoid account enumeration.
+		if errors.Is(err, ErrUserNotFound) {
+			return nil, ErrInvalidCredentials
+		}
+		return nil, ErrServiceUnavailable
+	}
+	if !user.IsActive() {
+		return nil, ErrUserNotActive
+	}
+	if !s.CheckPassword(password, user.PasswordHash) {
+		return nil, ErrInvalidCredentials
+	}
+	return user, nil
+}
+
+// RecordSuccessfulLogin updates last-login activity after a non-standard login
+// flow finishes with a real session.
+//
+// It is a thin, fire-and-forget wrapper over touchUserLogin; it returns no
+// error, so callers treat the bookkeeping as best-effort.
+func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
+	s.touchUserLogin(ctx, userID)
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index de555478..a2644fcd 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
+ SettingKeyForceEmailOnThirdPartySignup,
SettingKeyRegistrationEmailSuffixWhitelist,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
@@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,
diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go
index 5cf1e860..bb97c2aa 100644
--- a/backend/internal/service/setting_service_public_test.go
+++ b/backend/internal/service/setting_service_public_test.go
@@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}
+
+// TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup
+// verifies that a stored "true" value for the force-email-on-third-party-signup
+// key is reflected in the PublicSettings view returned by the service.
+func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+	// Only the setting under test is seeded; other keys fall back to defaults.
+	repo := &settingPublicRepoStub{
+		values: map[string]string{
+			SettingKeyForceEmailOnThirdPartySignup: "true",
+		},
+	}
+	svc := NewSettingService(repo, &config.Config{})
+
+	settings, err := svc.GetPublicSettings(context.Background())
+	require.NoError(t, err)
+	require.True(t, settings.ForceEmailOnThirdPartySignup)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index e991ebef..72db4e31 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
+ ForceEmailOnThirdPartySignup bool
RegistrationEmailSuffixWhitelist []string
PromoCodeEnabled bool
PasswordResetEnabled bool
--
GitLab
From 6ea3f42e2f825f165c98a63fa6b5f472f6853b33 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 19:30:19 +0800
Subject: [PATCH 040/261] feat: add oauth callback email binding ui
---
.../api/__tests__/auth-oauth-adoption.spec.ts | 91 +++++++
frontend/src/api/auth.ts | 98 +++++--
frontend/src/stores/app.ts | 1 +
frontend/src/types/index.ts | 1 +
.../src/views/auth/LinuxDoCallbackView.vue | 257 +++++++++++++++++-
frontend/src/views/auth/OidcCallbackView.vue | 128 ++++++++-
.../src/views/auth/WechatCallbackView.vue | 108 ++++++++
.../__tests__/LinuxDoCallbackView.spec.ts | 105 +++++++
.../auth/__tests__/OidcCallbackView.spec.ts | 71 +++++
.../auth/__tests__/WechatCallbackView.spec.ts | 88 ++++++
10 files changed, 914 insertions(+), 34 deletions(-)
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
index 9c0b4d55..f95332fb 100644
--- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -30,6 +30,20 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts bind-login decisions when finalizing pending oauth bind flow', async () => {
+ const { completePendingOAuthBindLogin } = await import('@/api/auth')
+
+ await completePendingOAuthBindLogin({
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
it('posts linuxdo invitation completion with adoption decisions', async () => {
const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
@@ -45,6 +59,21 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts linuxdo create-account completion with adoption decisions', async () => {
+ const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth')
+
+ await createPendingLinuxDoOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
it('posts oidc invitation completion with adoption decisions', async () => {
const { completeOIDCOAuthRegistration } = await import('@/api/auth')
@@ -60,6 +89,21 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts oidc create-account completion with adoption decisions', async () => {
+ const { createPendingOIDCOAuthAccount } = await import('@/api/auth')
+
+ await createPendingOIDCOAuthAccount('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
it('posts wechat invitation completion with adoption decisions', async () => {
const { completeWeChatOAuthRegistration } = await import('@/api/auth')
@@ -75,6 +119,21 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts wechat create-account completion with adoption decisions', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: false
+ })
+ })
+
it('classifies oauth completion results as login or bind', async () => {
const { getOAuthCompletionKind } = await import('@/api/auth')
@@ -82,6 +141,38 @@ describe('oauth adoption auth api', () => {
expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind')
})
+ it('provides bind-login utility helpers for invitation and suggested profile states', async () => {
+ const {
+ getPendingOAuthBindLoginKind,
+ hasPendingOAuthSuggestedProfile,
+ isPendingOAuthCreateAccountRequired
+ } = await import('@/api/auth')
+
+ expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login')
+ expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind')
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'invitation_required'
+ })
+ ).toBe(true)
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'other'
+ })
+ ).toBe(false)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_display_name: 'OAuth Nick'
+ })
+ ).toBe(true)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_avatar_url: 'https://cdn.example/avatar.png'
+ })
+ ).toBe(true)
+ expect(hasPendingOAuthSuggestedProfile({})).toBe(false)
+ })
+
it('prepares an oauth bind access token cookie before redirect binding', async () => {
localStorage.setItem('auth_token', 'access-token-value')
const setCookie = vi.fn()
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index c11bd90b..98b20154 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -193,7 +193,7 @@ export interface OAuthTokenResponse {
token_type?: string
}
-export interface PendingOAuthExchangeResponse extends Partial {
+export interface PendingOAuthBindLoginResponse extends Partial {
redirect?: string
error?: string
adoption_required?: boolean
@@ -201,6 +201,10 @@ export interface PendingOAuthExchangeResponse extends Partial
+): boolean {
+ return completion.error === 'invitation_required'
+}
+
+export function hasPendingOAuthSuggestedProfile(
+ completion: Pick<
+ PendingOAuthBindLoginResponse,
+ 'suggested_display_name' | 'suggested_avatar_url'
+ >
+): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
export function persistOAuthTokenContext(tokens: Partial): void {
if (tokens.refresh_token) {
setRefreshToken(tokens.refresh_token)
@@ -431,11 +456,7 @@ export async function completeLinuxDoOAuthRegistration(
invitationCode: string,
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post('/auth/oauth/linuxdo/complete-registration', {
- invitation_code: invitationCode,
- ...serializeOAuthAdoptionDecision(decision)
- })
- return data
+ return createPendingLinuxDoOAuthAccount(invitationCode, decision)
}
/**
@@ -447,34 +468,68 @@ export async function completeOIDCOAuthRegistration(
invitationCode: string,
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post('/auth/oauth/oidc/complete-registration', {
- invitation_code: invitationCode,
- ...serializeOAuthAdoptionDecision(decision)
- })
- return data
+ return createPendingOIDCOAuthAccount(invitationCode, decision)
}
export async function completeWeChatOAuthRegistration(
invitationCode: string,
decision?: OAuthAdoptionDecision
): Promise {
- const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', {
- invitation_code: invitationCode,
- ...serializeOAuthAdoptionDecision(decision)
- })
+ return createPendingWeChatOAuthAccount(invitationCode, decision)
+}
+
+async function createPendingOAuthAccount(
+ provider: 'linuxdo' | 'oidc' | 'wechat',
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ `/auth/oauth/${provider}/complete-registration`,
+ {
+ invitation_code: invitationCode,
+ ...serializeOAuthAdoptionDecision(decision)
+ }
+ )
return data
}
-export async function exchangePendingOAuthCompletion(
+export async function createPendingLinuxDoOAuthAccount(
+ invitationCode: string,
decision?: OAuthAdoptionDecision
-): Promise {
- const { data } = await apiClient.post(
+): Promise {
+ return createPendingOAuthAccount('linuxdo', invitationCode, decision)
+}
+
+export async function createPendingOIDCOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('oidc', invitationCode, decision)
+}
+
+export async function createPendingWeChatOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return createPendingOAuthAccount('wechat', invitationCode, decision)
+}
+
+export async function completePendingOAuthBindLogin(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
'/auth/oauth/pending/exchange',
serializeOAuthAdoptionDecision(decision)
)
return data
}
+export async function exchangePendingOAuthCompletion(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return completePendingOAuthBindLogin(decision)
+}
+
export const authAPI = {
login,
login2FA,
@@ -498,6 +553,13 @@ export const authAPI = {
resetPassword,
refreshToken,
revokeAllSessions,
+ getPendingOAuthBindLoginKind,
+ isPendingOAuthCreateAccountRequired,
+ hasPendingOAuthSuggestedProfile,
+ completePendingOAuthBindLogin,
+ createPendingLinuxDoOAuthAccount,
+ createPendingOIDCOAuthAccount,
+ createPendingWeChatOAuthAccount,
exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
completeOIDCOAuthRegistration,
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index 1b1af87b..a8e03a51 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -316,6 +316,7 @@ export const useAppStore = defineStore('app', () => {
return {
registration_enabled: false,
email_verify_enabled: false,
+ force_email_on_third_party_signup: false,
registration_email_suffix_whitelist: [],
promo_code_enabled: true,
password_reset_enabled: false,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index a19d6c26..a4b2277b 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -142,6 +142,7 @@ export interface CustomEndpoint {
export interface PublicSettings {
registration_enabled: boolean
email_verify_enabled: boolean
+ force_email_on_third_party_signup: boolean
registration_email_suffix_whitelist: string[]
promo_code_enabled: boolean
password_reset_enabled: boolean
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 6dc8f242..00b73868 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -11,7 +11,10 @@
-
+
@@ -127,11 +214,12 @@
diff --git a/frontend/src/components/admin/monitor/MonitorFiltersBar.vue b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
new file mode 100644
index 00000000..ebb06a68
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
@@ -0,0 +1,95 @@
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.createButton') }}
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
new file mode 100644
index 00000000..920c3f79
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -0,0 +1,297 @@
+
+
+
+
+
+
+
+ {{ t('common.cancel') }}
+
+
+ {{ submitting
+ ? t('common.submitting')
+ : editing ? t('common.update') : t('common.create') }}
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
new file mode 100644
index 00000000..eefe4073
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
@@ -0,0 +1,64 @@
+
+
+
+
+ {{ t('admin.channelMonitor.form.selectKeyHint') }}
+
+
+ {{ t('common.loading') }}
+
+
+ {{ t('admin.channelMonitor.form.noActiveKey') }}
+
+
+
+ {{ k.name }}
+ {{ maskKey(k.key) }}
+
+
+
+
+
+
+ {{ t('common.cancel') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue
new file mode 100644
index 00000000..eccec828
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue
@@ -0,0 +1,71 @@
+
+
+
{{ row.primary_model }}
+
+
+
+ {{ statusLabel(row.primary_status) }}
+
+
+
+
+ {{ row.primary_model }}
+
+ {{ statusLabel(row.primary_status) }}
+
+
+
+ {{ t('monitorCommon.extraModelsEmpty') }}
+
+
+
+ {{ t('monitorCommon.extraModelsHeader') }}
+
+
+
+
+ {{ t('admin.channelMonitor.columns.primaryModel') }}
+ {{ t('admin.channelMonitor.columns.actions') }}
+ {{ t('admin.channelMonitor.columns.latency') }}
+
+
+
+
+ {{ m.model }}
+
+
+ {{ statusLabel(m.status) }}
+
+
+ {{ formatLatency(m.latency_ms) }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue
new file mode 100644
index 00000000..02fa6e8d
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue
@@ -0,0 +1,56 @@
+
+
+
+
+
+ {{ r.model }}
+ {{ r.message }}
+
+
+
+ {{ statusLabel(r.status) }}
+
+ {{ formatLatency(r.latency_ms) }} ms
+
+
+
+
+
+
+ {{ t('common.close') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 92dcc519..23d0f4e9 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -38,7 +38,7 @@
'sidebar-link-collapsed': sidebarCollapsed
}"
:title="sidebarCollapsed ? item.label : undefined"
- @click="sidebarCollapsed ? undefined : toggleGroup(item)"
+ @click="handleGroupClick(item)"
>
import { computed, h, onMounted, ref, watch } from 'vue'
-import { useRoute } from 'vue-router'
+import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { useAdminSettingsStore, useAppStore, useAuthStore, useOnboardingStore } from '@/stores'
import VersionBadge from '@/components/common/VersionBadge.vue'
@@ -194,11 +194,17 @@ interface NavItem {
iconSvg?: string
hideInSimpleMode?: boolean
children?: NavItem[]
+ /**
+ * When true, the parent item only toggles the expand/collapse state and
+ * does NOT navigate to its `path`. The `path` is purely a stable key.
+ */
+ expandOnly?: boolean
}
const { t } = useI18n()
const route = useRoute()
+const router = useRouter()
const appStore = useAppStore()
const authStore = useAuthStore()
const onboardingStore = useOnboardingStore()
@@ -549,6 +555,41 @@ const ChevronDoubleRightIcon = {
)
}
+const SignalIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M9.348 14.651a3.75 3.75 0 010-5.303m5.304 0a3.75 3.75 0 010 5.303m-7.425 2.122a6.75 6.75 0 010-9.546m9.546 0a6.75 6.75 0 010 9.546M5.106 18.894c-3.808-3.807-3.808-9.98 0-13.788m13.788 0c3.808 3.807 3.808 9.98 0 13.788M12 12h.008v.008H12V12zm.375 0a.375.375 0 11-.75 0 .375.375 0 01.75 0z'
+ })
+ ]
+ )
+}
+
+const PriceTagIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M9.568 3H5.25A2.25 2.25 0 003 5.25v4.318c0 .597.237 1.17.659 1.591l9.581 9.581c.699.699 1.78.872 2.607.33a18.095 18.095 0 005.223-5.223c.542-.827.369-1.908-.33-2.607L11.16 3.66A2.25 2.25 0 009.568 3z'
+ }),
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M6 6h.008v.008H6V6z'
+ })
+ ]
+ )
+}
+
const ChevronDownIcon = {
render: () =>
h(
@@ -570,6 +611,7 @@ const userNavItems = computed((): NavItem[] => {
{ path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon },
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
+ { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon },
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
...(appStore.cachedPublicSettings?.payment_enabled
? [
@@ -608,6 +650,7 @@ const personalNavItems = computed((): NavItem[] => {
const items: NavItem[] = [
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
+ { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon },
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
...(appStore.cachedPublicSettings?.payment_enabled
? [
@@ -664,7 +707,17 @@ const adminNavItems = computed((): NavItem[] => {
: []),
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
- { path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
+ {
+ path: '/admin/channels',
+ label: t('nav.channelManagement'),
+ icon: ChannelIcon,
+ hideInSimpleMode: true,
+ expandOnly: true,
+ children: [
+ { path: '/admin/channels/pricing', label: t('nav.channelPricing'), icon: PriceTagIcon },
+ { path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon },
+ ],
+ },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
{ path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon },
{ path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon },
@@ -678,6 +731,7 @@ const adminNavItems = computed((): NavItem[] => {
label: t('nav.orderManagement'),
icon: OrderIcon,
hideInSimpleMode: true,
+ expandOnly: true,
children: [
{ path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon },
{ path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon },
@@ -764,6 +818,28 @@ function toggleGroup(item: NavItem) {
}
}
+/**
+ * Click handler for collapsible parent items.
+ * - When sidebar is collapsed: do nothing (children are not visible).
+ * - When `expandOnly` is true: only toggle expand state.
+ * - Otherwise (default, e.g. /admin/orders): navigate to the parent path
+ * (router-link semantics) and ensure the group is expanded.
+ */
+function handleGroupClick(item: NavItem) {
+ // A collapsed sidebar renders no children, so there is nothing to toggle.
+ if (sidebarCollapsed.value) return
+ if (item.expandOnly) {
+ // Parent acts purely as a group header; its `path` is only a stable key.
+ toggleGroup(item)
+ return
+ }
+ // Default (e.g. /admin/orders): navigate to the parent path and make sure
+ // the group stays expanded. NOTE(review): router.push returns a Promise
+ // that is intentionally not awaited here — navigation failures are left to
+ // the router's own error handling; confirm that is acceptable.
+ if (route.path !== item.path) {
+ router.push(item.path)
+ }
+ if (!expandedGroups.value.has(item.path)) {
+ expandedGroups.value.add(item.path)
+ }
+}
+
// Initialize theme
const savedTheme = localStorage.getItem('theme')
if (
diff --git a/frontend/src/components/user/MonitorDetailDialog.vue b/frontend/src/components/user/MonitorDetailDialog.vue
new file mode 100644
index 00000000..564f461b
--- /dev/null
+++ b/frontend/src/components/user/MonitorDetailDialog.vue
@@ -0,0 +1,114 @@
+
+
+
+ {{ t('common.loading') }}
+
+
+ {{ t('channelStatus.detailLoadError') }}
+
+
+
+
+
+ {{ t('channelStatus.detailColumns.model') }}
+ {{ t('channelStatus.detailColumns.latestStatus') }}
+ {{ t('channelStatus.detailColumns.latestLatency') }}
+ {{ t('channelStatus.detailColumns.availability7d') }}
+ {{ t('channelStatus.detailColumns.availability15d') }}
+ {{ t('channelStatus.detailColumns.availability30d') }}
+ {{ t('channelStatus.detailColumns.avgLatency7d') }}
+
+
+
+
+ {{ m.model }}
+
+
+ {{ statusLabel(m.latest_status) }}
+
+
+ {{ formatLatency(m.latest_latency_ms) }}
+ {{ formatPercent(m.availability_7d) }}
+ {{ formatPercent(m.availability_15d) }}
+ {{ formatPercent(m.availability_30d) }}
+ {{ formatLatency(m.avg_latency_7d_ms) }}
+
+
+
+
+
+
+
+
+ {{ t('channelStatus.closeDetail') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/MonitorPrimaryModelCell.vue b/frontend/src/components/user/MonitorPrimaryModelCell.vue
new file mode 100644
index 00000000..32620b2a
--- /dev/null
+++ b/frontend/src/components/user/MonitorPrimaryModelCell.vue
@@ -0,0 +1,71 @@
+
+
+
{{ row.primary_model }}
+
+
+
+ {{ statusLabel(row.primary_status) }}
+
+
+
+
+ {{ row.primary_model }}
+
+ {{ statusLabel(row.primary_status) }}
+
+
+
+ {{ t('monitorCommon.extraModelsEmpty') }}
+
+
+
+ {{ t('monitorCommon.extraModelsHeader') }}
+
+
+
+
+ {{ t('channelStatus.detailColumns.model') }}
+ {{ t('channelStatus.detailColumns.latestStatus') }}
+ {{ t('channelStatus.detailColumns.latestLatency') }}
+
+
+
+
+ {{ m.model }}
+
+
+ {{ statusLabel(m.status) }}
+
+
+ {{ formatLatency(m.latency_ms) }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/composables/useChannelMonitorFormat.ts b/frontend/src/composables/useChannelMonitorFormat.ts
new file mode 100644
index 00000000..fbb310fa
--- /dev/null
+++ b/frontend/src/composables/useChannelMonitorFormat.ts
@@ -0,0 +1,97 @@
+/**
+ * Shared formatting helpers for channel monitor views (admin + user).
+ *
+ * Centralises:
+ * - status / provider label + badge class lookups
+ * - latency / availability / percent number formatting
+ *
+ * i18n keys live under `monitorCommon.*` so admin and user views share the
+ * same translation source.
+ */
+
+import { useI18n } from 'vue-i18n'
+import type { MonitorStatus, Provider } from '@/api/admin/channelMonitor'
+import {
+ PROVIDER_OPENAI,
+ PROVIDER_ANTHROPIC,
+ PROVIDER_GEMINI,
+ STATUS_OPERATIONAL,
+ STATUS_DEGRADED,
+ STATUS_FAILED,
+ STATUS_ERROR,
+} from '@/constants/channelMonitor'
+
+// Fallback Tailwind badge classes for statuses/providers with no dedicated
+// colour mapping.
+const NEUTRAL_BADGE = 'bg-gray-100 text-gray-800 dark:bg-dark-700 dark:text-gray-300'
+
+// Minimal row shape consumed by formatAvailability; both admin and user
+// monitor rows structurally satisfy it.
+export interface AvailabilityRow {
+ primary_status: MonitorStatus | ''
+ availability_7d: number | null | undefined
+}
+
+export function useChannelMonitorFormat() {
+ const { t } = useI18n()
+
+ // Localised label for a monitor status; '' maps to the "unknown" placeholder.
+ function statusLabel(s: MonitorStatus | ''): string {
+ if (!s) return t('monitorCommon.status.unknown')
+ return t(`monitorCommon.status.${s}`)
+ }
+
+ // Tailwind badge classes per status; STATUS_ERROR and any unrecognised
+ // value share the neutral fallback.
+ function statusBadgeClass(s: MonitorStatus | ''): string {
+ switch (s) {
+ case STATUS_OPERATIONAL:
+ return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300'
+ case STATUS_DEGRADED:
+ return 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900/30 dark:text-yellow-300'
+ case STATUS_FAILED:
+ return 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-300'
+ case STATUS_ERROR:
+ default:
+ return NEUTRAL_BADGE
+ }
+ }
+
+ // Localised provider name for the known providers; anything else is shown
+ // verbatim, with '-' for an empty value.
+ function providerLabel(p: Provider | string): string {
+ if (p === PROVIDER_OPENAI || p === PROVIDER_ANTHROPIC || p === PROVIDER_GEMINI) {
+ return t(`monitorCommon.providers.${p}`)
+ }
+ return p || '-'
+ }
+
+ // Tailwind badge classes per provider; unknown providers use the neutral
+ // fallback.
+ function providerBadgeClass(p: Provider | string): string {
+ switch (p) {
+ case PROVIDER_OPENAI:
+ return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300'
+ case PROVIDER_ANTHROPIC:
+ return 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-300'
+ case PROVIDER_GEMINI:
+ return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-300'
+ default:
+ return NEUTRAL_BADGE
+ }
+ }
+
+ // Latency as a rounded integer string; null/undefined renders the
+ // localised empty marker.
+ function formatLatency(ms: number | null | undefined): string {
+ if (ms == null) return t('monitorCommon.latencyEmpty')
+ return String(Math.round(ms))
+ }
+
+ // Percentage with two decimals (e.g. "99.95%"); null/undefined/NaN → '-'.
+ function formatPercent(v: number | null | undefined): string {
+ if (v == null || Number.isNaN(v)) return '-'
+ return `${v.toFixed(2)}%`
+ }
+
+ // 7-day availability for a row; '-' when the row has no recorded status yet.
+ function formatAvailability(row: AvailabilityRow): string {
+ if (!row.primary_status) return '-'
+ return formatPercent(row.availability_7d)
+ }
+
+ return {
+ statusLabel,
+ statusBadgeClass,
+ providerLabel,
+ providerBadgeClass,
+ formatLatency,
+ formatPercent,
+ formatAvailability,
+ }
+}
diff --git a/frontend/src/constants/channelMonitor.ts b/frontend/src/constants/channelMonitor.ts
new file mode 100644
index 00000000..7523a878
--- /dev/null
+++ b/frontend/src/constants/channelMonitor.ts
@@ -0,0 +1,35 @@
+/**
+ * Channel monitor shared constants.
+ *
+ * Single source of truth for provider/status string values used by both the
+ * admin (`views/admin/ChannelMonitorView.vue`) and user-facing
+ * (`views/user/ChannelStatusView.vue`) screens, plus the shared composable
+ * `useChannelMonitorFormat`.
+ */
+
+import type { Provider, MonitorStatus } from '@/api/admin/channelMonitor'
+
+// Provider string values shared by the admin and user monitor views.
+export const PROVIDER_OPENAI: Provider = 'openai'
+export const PROVIDER_ANTHROPIC: Provider = 'anthropic'
+export const PROVIDER_GEMINI: Provider = 'gemini'
+
+// Ordered list used to build provider filter dropdowns.
+export const PROVIDERS: readonly Provider[] = [
+ PROVIDER_OPENAI,
+ PROVIDER_ANTHROPIC,
+ PROVIDER_GEMINI,
+]
+
+// Monitor status string values (badge colours and labels key off these).
+export const STATUS_OPERATIONAL: MonitorStatus = 'operational'
+export const STATUS_DEGRADED: MonitorStatus = 'degraded'
+export const STATUS_FAILED: MonitorStatus = 'failed'
+export const STATUS_ERROR: MonitorStatus = 'error'
+
+// Ordered list used to build status filter dropdowns.
+export const MONITOR_STATUSES: readonly MonitorStatus[] = [
+ STATUS_OPERATIONAL,
+ STATUS_DEGRADED,
+ STATUS_FAILED,
+ STATUS_ERROR,
+]
+
+/** Default polling interval (seconds) for new monitors. */
+export const DEFAULT_INTERVAL_SECONDS = 60
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 1b7ffa81..32fbce19 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -245,6 +245,7 @@ export default {
// Common
common: {
loading: 'Loading...',
+ submitting: 'Submitting...',
justNow: 'just now',
save: 'Save',
saved: 'Saved successfully',
@@ -363,7 +364,11 @@ export default {
orderManagement: 'Orders',
paymentDashboard: 'Payment Dashboard',
paymentConfig: 'Payment Config',
- paymentPlans: 'Plans'
+ paymentPlans: 'Plans',
+ channelManagement: 'Channels',
+ channelPricing: 'Channel Pricing',
+ channelMonitor: 'Channel Monitor',
+ channelStatus: 'Channel Status',
},
// Auth
@@ -846,6 +851,58 @@ export default {
userAgent: 'User-Agent'
},
+ // Shared keys for channel monitor (admin + user views)
+ monitorCommon: {
+ status: {
+ operational: 'Operational',
+ degraded: 'Degraded',
+ failed: 'Failed',
+ error: 'Error',
+ unknown: '-'
+ },
+ providers: {
+ openai: 'OpenAI',
+ anthropic: 'Anthropic',
+ gemini: 'Gemini'
+ },
+ extraModelsHeader: 'Extra Models',
+ extraModelsEmpty: 'No extra models',
+ latencyEmpty: '-'
+ },
+
+ // Channel Status (user-facing read-only view)
+ channelStatus: {
+ title: 'Channel Status',
+ description: 'Inspect channel availability, latency and recent status',
+ searchPlaceholder: 'Search channels...',
+ allProviders: 'All Providers',
+ loadError: 'Failed to load channel status',
+ detailLoadError: 'Failed to load channel detail',
+ detailTitle: 'Channel Detail',
+ closeDetail: 'Close',
+ columns: {
+ name: 'Name',
+ provider: 'Provider',
+ groupName: 'Group',
+ primaryModel: 'Primary Model',
+ availability7d: '7d Availability',
+ latency: 'Latency (ms)'
+ },
+ detailColumns: {
+ model: 'Model',
+ latestStatus: 'Latest Status',
+ latestLatency: 'Latest Latency (ms)',
+ availability7d: '7d Availability',
+ availability15d: '15d Availability',
+ availability30d: '30d Availability',
+ avgLatency7d: '7d Avg Latency (ms)'
+ },
+ empty: {
+ title: 'No channels available',
+ description: 'No monitored channels have been configured yet.'
+ }
+ },
+
// Redeem
redeem: {
title: 'Redeem Code',
@@ -2014,6 +2071,69 @@ export default {
}
},
+ // Channel Monitor
+ channelMonitor: {
+ title: 'Channel Monitor',
+ description: 'Monitor channel availability, latency and status',
+ searchPlaceholder: 'Search monitor name...',
+ allProviders: 'All Providers',
+ allStatus: 'All Status',
+ enabledFilter: 'Enabled',
+ onlyEnabled: 'Enabled only',
+ onlyDisabled: 'Disabled only',
+ createButton: 'Create Monitor',
+ createTitle: 'Create Channel Monitor',
+ editTitle: 'Edit Channel Monitor',
+ runNow: 'Run Now',
+ runSuccess: 'Check completed',
+ runFailed: 'Check failed',
+ apiKeyDecryptFailed: 'API Key decryption failed. Please re-edit this monitor with a fresh key.',
+ createSuccess: 'Monitor created',
+ updateSuccess: 'Monitor updated',
+ deleteSuccess: 'Monitor deleted',
+ loadError: 'Failed to load monitors',
+ deleteConfirm: 'Are you sure you want to delete monitor "{name}"? This action cannot be undone.',
+ nameRequired: 'Please enter a monitor name',
+ primaryModelRequired: 'Please enter a primary model',
+ columns: {
+ name: 'Name',
+ provider: 'Provider',
+ primaryModel: 'Primary Model',
+ availability7d: '7d Availability',
+ latency: 'Latency (ms)',
+ enabled: 'Enabled',
+ actions: 'Actions'
+ },
+ form: {
+ name: 'Name',
+ namePlaceholder: 'Enter monitor name',
+ provider: 'Provider',
+ endpoint: 'Endpoint',
+ endpointPlaceholder: 'https://api.example.com',
+ useCurrentDomain: 'Use current service',
+ apiKey: 'API Key',
+ apiKeyPlaceholder: 'Enter API Key',
+ apiKeyEditPlaceholder: 'Leave blank to keep current key',
+ useMyKey: 'Use my key',
+ selectKeyTitle: 'Select my API Key',
+ selectKeyHint: 'Only your active, non-expired keys are listed.',
+ noActiveKey: 'No active API keys available',
+ primaryModel: 'Primary Model',
+ primaryModelPlaceholder: 'gpt-4o-mini',
+ extraModels: 'Extra Models',
+ extraModelsPlaceholder: 'Press Enter to add extra model',
+ groupName: 'Group Name',
+ groupNamePlaceholder: 'Optional, used to group rows in user view',
+ intervalSeconds: 'Interval (seconds)',
+ intervalSecondsHint: 'Range: 15 - 3600 seconds',
+ enabled: 'Enable monitor',
+ kindRequired: 'Please select a provider'
+ },
+ runResultTitle: 'Check Result',
+ noMonitorsYet: 'No monitors yet',
+ createFirstMonitor: 'Create your first monitor to track channel availability'
+ },
+
// Subscriptions
subscriptions: {
title: 'Subscription Management',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index beb6841f..dd3af363 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -245,6 +245,7 @@ export default {
// Common
common: {
loading: '加载中...',
+ submitting: '提交中...',
justNow: '刚刚',
save: '保存',
saved: '保存成功',
@@ -363,7 +364,11 @@ export default {
orderManagement: '订单管理',
paymentDashboard: '支付概览',
paymentConfig: '支付配置',
- paymentPlans: '订阅套餐'
+ paymentPlans: '订阅套餐',
+ channelManagement: '渠道管理',
+ channelPricing: '渠道定价',
+ channelMonitor: '渠道监控',
+ channelStatus: '渠道状态',
},
// Auth
@@ -850,6 +855,58 @@ export default {
userAgent: 'User-Agent'
},
+ // Shared keys for channel monitor (admin + user views)
+ monitorCommon: {
+ status: {
+ operational: '正常',
+ degraded: '降级',
+ failed: '失败',
+ error: '错误',
+ unknown: '-'
+ },
+ providers: {
+ openai: 'OpenAI',
+ anthropic: 'Anthropic',
+ gemini: 'Gemini'
+ },
+ extraModelsHeader: '附加模型',
+ extraModelsEmpty: '无附加模型',
+ latencyEmpty: '-'
+ },
+
+ // Channel Status (user-facing read-only view)
+ channelStatus: {
+ title: '渠道状态',
+ description: '查看渠道可用性、延迟和近期状态',
+ searchPlaceholder: '搜索渠道...',
+ allProviders: '全部供应商',
+ loadError: '加载渠道状态失败',
+ detailLoadError: '加载渠道详情失败',
+ detailTitle: '渠道详情',
+ closeDetail: '关闭',
+ columns: {
+ name: '名称',
+ provider: '供应商',
+ groupName: '分组',
+ primaryModel: '主模型',
+ availability7d: '7 天可用率',
+ latency: '延迟 (ms)'
+ },
+ detailColumns: {
+ model: '模型',
+ latestStatus: '最新状态',
+ latestLatency: '最新延迟 (ms)',
+ availability7d: '7 天可用率',
+ availability15d: '15 天可用率',
+ availability30d: '30 天可用率',
+ avgLatency7d: '7 天平均延迟 (ms)'
+ },
+ empty: {
+ title: '暂无可显示的渠道',
+ description: '管理员尚未配置可监控的渠道。'
+ }
+ },
+
// Redeem
redeem: {
title: '兑换码',
@@ -2093,6 +2150,69 @@ export default {
}
},
+ // Channel Monitor
+ channelMonitor: {
+ title: '渠道监控',
+ description: '监测各渠道的可用性、延迟和状态',
+ searchPlaceholder: '搜索监控名称...',
+ allProviders: '全部供应商',
+ allStatus: '全部状态',
+ enabledFilter: '启用状态',
+ onlyEnabled: '仅启用',
+ onlyDisabled: '仅禁用',
+ createButton: '新增监控',
+ createTitle: '新增渠道监控',
+ editTitle: '编辑渠道监控',
+ runNow: '立即检测',
+ runSuccess: '检测完成',
+ runFailed: '检测失败',
+ apiKeyDecryptFailed: 'API Key 解密失败,请重新编辑该监控并填入新的 Key',
+ createSuccess: '监控创建成功',
+ updateSuccess: '监控更新成功',
+ deleteSuccess: '监控删除成功',
+ loadError: '加载监控列表失败',
+ deleteConfirm: '确定要删除监控「{name}」吗?此操作不可撤销。',
+ nameRequired: '请输入监控名称',
+ primaryModelRequired: '请输入主模型',
+ columns: {
+ name: '名称',
+ provider: '供应商',
+ primaryModel: '主模型',
+ availability7d: '7 天可用率',
+ latency: '延迟 (ms)',
+ enabled: '启用',
+ actions: '操作'
+ },
+ form: {
+ name: '名称',
+ namePlaceholder: '输入监控名称',
+ provider: '供应商',
+ endpoint: '上游地址',
+ endpointPlaceholder: 'https://api.example.com',
+ useCurrentDomain: '使用当前服务',
+ apiKey: 'API Key',
+ apiKeyPlaceholder: '请输入 API Key',
+ apiKeyEditPlaceholder: '留空表示不修改',
+ useMyKey: '使用我的 Key',
+ selectKeyTitle: '选择我的 API Key',
+ selectKeyHint: '仅显示当前账号下处于「启用」状态且未过期的 Key。',
+ noActiveKey: '没有可用的启用状态 Key',
+ primaryModel: '主模型',
+ primaryModelPlaceholder: 'gpt-4o-mini',
+ extraModels: '附加模型',
+ extraModelsPlaceholder: '回车添加附加模型',
+ groupName: '分组名称',
+ groupNamePlaceholder: '可选,用于在用户视图中聚合显示',
+ intervalSeconds: '检测间隔 (秒)',
+ intervalSecondsHint: '范围:15 - 3600 秒',
+ enabled: '启用监控',
+ kindRequired: '请选择供应商'
+ },
+ runResultTitle: '检测结果',
+ noMonitorsYet: '暂无监控',
+ createFirstMonitor: '创建第一个监控来跟踪渠道可用性'
+ },
+
// Subscriptions Management
subscriptions: {
title: '订阅管理',
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index b97ccb5d..491a984d 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -360,6 +360,10 @@ const routes: RouteRecordRaw[] = [
},
{
path: '/admin/channels',
+ redirect: '/admin/channels/pricing'
+ },
+ {
+ path: '/admin/channels/pricing',
name: 'AdminChannels',
component: () => import('@/views/admin/ChannelsView.vue'),
meta: {
@@ -370,6 +374,29 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.channels.description'
}
},
+ {
+ path: '/admin/channels/monitor',
+ name: 'AdminChannelMonitor',
+ component: () => import('@/views/admin/ChannelMonitorView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: true,
+ title: 'Channel Monitor',
+ titleKey: 'admin.channelMonitor.title',
+ descriptionKey: 'admin.channelMonitor.description'
+ }
+ },
+ {
+ path: '/monitor',
+ name: 'ChannelStatus',
+ component: () => import('@/views/user/ChannelStatusView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: false,
+ title: 'Channel Status',
+ titleKey: 'nav.channelStatus'
+ }
+ },
{
path: '/admin/subscriptions',
name: 'AdminSubscriptions',
diff --git a/frontend/src/views/admin/ChannelMonitorView.vue b/frontend/src/views/admin/ChannelMonitorView.vue
new file mode 100644
index 00000000..8f0a1e2f
--- /dev/null
+++ b/frontend/src/views/admin/ChannelMonitorView.vue
@@ -0,0 +1,295 @@
+
+
+
+
+
+
+
+
+
+
+
+ {{ value }}
+
+
+
+
+
+
+
+
+ {{ providerLabel(row.provider) }}
+
+
+
+
+
+
+
+
+ {{ formatAvailability(row) }}
+
+
+
+ {{ formatLatency(row.primary_latency_ms) }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue
new file mode 100644
index 00000000..9f5fe8d1
--- /dev/null
+++ b/frontend/src/views/user/ChannelStatusView.vue
@@ -0,0 +1,208 @@
+
+
+
+
+
+
+
+
+
+
+
+ {{ row.name }}
+
+
+
+
+
+ {{ providerLabel(row.provider) }}
+
+
+
+
+ {{ value || '-' }}
+
+
+
+
+
+
+
+
+ {{ formatAvailability(row) }}
+
+
+
+
+
+ {{ formatLatency(row.primary_latency_ms) }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
--
GitLab
From 58b2cc380fefc180d96c3baf10d3214026c1341c Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:22:00 +0800
Subject: [PATCH 046/261] test: harden payment result resume flow
---
frontend/src/views/user/PaymentResultView.vue | 50 +++----
.../user/__tests__/PaymentResultView.spec.ts | 132 ++++++++++++++++++
2 files changed, 157 insertions(+), 25 deletions(-)
create mode 100644 frontend/src/views/user/__tests__/PaymentResultView.spec.ts
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 6431ddf6..e1db3ce2 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -94,6 +94,7 @@ import { ref, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRoute, useRouter } from 'vue-router'
import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue'
+import { PAYMENT_RECOVERY_STORAGE_KEY, readPaymentRecoverySnapshot } from '@/components/payment/paymentFlow'
import { usePaymentStore } from '@/stores/payment'
import { paymentAPI } from '@/api/payment'
import type { PaymentOrder } from '@/types/payment'
@@ -129,46 +130,46 @@ const feeAmount = computed(() => {
})
const isSuccess = computed(() => {
- // Always prioritize actual order status from backend
- if (order.value) {
- return SUCCESS_STATUSES.has(order.value.status)
- }
- // Fallback only when order not loaded
- if (route.query.status === 'success') return true
- if (route.query.trade_status === 'TRADE_SUCCESS') return true
- return false
+ return !!order.value && SUCCESS_STATUSES.has(order.value.status)
})
-/** Extract numeric order ID from out_trade_no like "sub2_46" → 46 */
-function parseOutTradeNo(outTradeNo: string): number {
- const match = outTradeNo.match(/_(\d+)$/)
- return match ? Number(match[1]) : 0
-}
-
onMounted(async () => {
- // Try order_id first (internal navigation from QRCode/Stripe pages)
+ const resumeToken = typeof route.query.resume_token === 'string'
+ ? route.query.resume_token
+ : ''
let orderId = Number(route.query.order_id) || 0
const outTradeNo = String(route.query.out_trade_no || '')
- // Fallback: EasyPay return URL with out_trade_no
- if (!orderId && outTradeNo) {
- orderId = parseOutTradeNo(outTradeNo)
- // Store return info for display when order lookup fails
+ if (!orderId && resumeToken && typeof window !== 'undefined') {
+ const restored = readPaymentRecoverySnapshot(
+ window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY),
+ { resumeToken },
+ )
+ if (restored?.orderId) {
+ orderId = restored.orderId
+ }
+ }
+
+ if (orderId) {
+ try {
+ order.value = await paymentStore.pollOrderStatus(orderId)
+ } catch (_err: unknown) {
+ // Order lookup failed, will try legacy fallback below when possible.
+ }
+ }
+
+ if (!order.value && outTradeNo) {
returnInfo.value = {
outTradeNo,
money: String(route.query.money || ''),
type: String(route.query.type || ''),
tradeStatus: String(route.query.trade_status || ''),
}
- }
- // Verify payment via public endpoint (works without login)
- if (outTradeNo) {
try {
const result = await paymentAPI.verifyOrderPublic(outTradeNo)
order.value = result.data
} catch (_err: unknown) {
- // Public verify failed, try authenticated endpoint if logged in
try {
const result = await paymentAPI.verifyOrder(outTradeNo)
order.value = result.data
@@ -176,12 +177,11 @@ onMounted(async () => {
}
}
- // Normal order lookup by ID (if verify didn't load the order)
if (!order.value && orderId) {
try {
order.value = await paymentStore.pollOrderStatus(orderId)
} catch (_err: unknown) {
- // Order lookup failed, will show returnInfo fallback
+ // Order lookup failed, will show returnInfo fallback.
}
}
loading.value = false
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
new file mode 100644
index 00000000..b06217ab
--- /dev/null
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -0,0 +1,132 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+// Hoisted so the vi.mock factories below (which are themselves hoisted by
+// Vitest) can reference these handles safely.
+// NOTE(review): the `Record` type arguments appear to have been stripped in
+// this patch (likely `Record<string, unknown>`) — confirm against the repo.
+const routeState = vi.hoisted(() => ({
+ query: {} as Record,
+}))
+
+const routerPush = vi.hoisted(() => vi.fn())
+const pollOrderStatus = vi.hoisted(() => vi.fn())
+const verifyOrderPublic = vi.hoisted(() => vi.fn())
+const verifyOrder = vi.hoisted(() => vi.fn())
+
+// Keep the real vue-router exports but swap in the controllable route/router.
+vi.mock('vue-router', async () => {
+ const actual = await vi.importActual('vue-router')
+ return {
+ ...actual,
+ useRoute: () => routeState,
+ useRouter: () => ({ push: routerPush }),
+ }
+})
+
+// Identity translator: assertions match on i18n keys, not rendered strings.
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+ }
+})
+
+vi.mock('@/stores/payment', () => ({
+ usePaymentStore: () => ({
+ pollOrderStatus,
+ }),
+}))
+
+vi.mock('@/api/payment', () => ({
+ paymentAPI: {
+ verifyOrderPublic,
+ verifyOrder,
+ },
+}))
+
+import PaymentResultView from '../PaymentResultView.vue'
+import { PAYMENT_RECOVERY_STORAGE_KEY } from '@/components/payment/paymentFlow'
+
+// Minimal PaymentOrder-shaped fixture; only `status` varies between tests.
+const orderFactory = (status: string) => ({
+ id: 42,
+ user_id: 9,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ payment_type: 'alipay',
+ out_trade_no: 'sub2_20260420abcd1234',
+ status,
+ order_type: 'balance',
+ created_at: '2026-04-20T12:00:00Z',
+ expires_at: '2026-04-20T12:30:00Z',
+ refund_amount: 0,
+})
+
+describe('PaymentResultView', () => {
+ // Reset route query, router/API mocks and persisted recovery state so each
+ // test starts from a clean slate.
+ beforeEach(() => {
+ routeState.query = {}
+ routerPush.mockReset()
+ pollOrderStatus.mockReset()
+ verifyOrderPublic.mockReset()
+ verifyOrder.mockReset()
+ window.localStorage.clear()
+ })
+
+ it('restores order id from a matching resume token and does not trust query success flags', async () => {
+ // Query claims success, but the backend order is still PENDING — the view
+ // must show the failed/pending state, proving query flags are ignored.
+ routeState.query = {
+ resume_token: 'resume-42',
+ status: 'success',
+ }
+ // Persisted recovery snapshot whose resumeToken matches the query token,
+ // letting the view recover orderId 42 without an order_id param.
+ window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({
+ orderId: 42,
+ amount: 88,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/42',
+ clientSecret: '',
+ payAmount: 88,
+ orderType: 'balance',
+ paymentMode: 'redirect',
+ resumeToken: 'resume-42',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }))
+ pollOrderStatus.mockResolvedValue(orderFactory('PENDING'))
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(pollOrderStatus).toHaveBeenCalledWith(42)
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ // Identity-translator mock means we assert on raw i18n keys.
+ expect(wrapper.text()).toContain('payment.result.failed')
+ expect(wrapper.text()).not.toContain('payment.result.success')
+ })
+
+ it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => {
+ // No order_id or resume token: the legacy EasyPay return params must
+ // drive the public verify endpoint instead.
+ routeState.query = {
+ out_trade_no: 'legacy-123',
+ trade_status: 'TRADE_SUCCESS',
+ }
+ verifyOrderPublic.mockResolvedValue({
+ data: orderFactory('PAID'),
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123')
+ expect(wrapper.text()).toContain('payment.result.success')
+ })
+})
--
GitLab
From 40d4e167cd8093bc3f21d4dbe205946df9c0aead Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 20 Apr 2026 20:06:53 +0800
Subject: [PATCH 047/261] feat(payment): i18n payment error codes and label
localization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Pairs with the backend structured payment errors (reason + metadata). The
frontend now maps reason codes to localized messages with metadata as
interpolation variables, and automatically localizes raw config-field names
(e.g. "certSerial" → "证书序列号") using the existing UI-label i18n
namespace.
- frontend/src/utils/apiError.ts
- extractApiErrorCode now prefers the string `reason` over the numeric HTTP
`code`; reason is granular enough to drive i18n lookup, HTTP code is not.
- New extractApiErrorMetadata to pull interpolation params off the error.
- New extractI18nErrorMessage(err, t, namespace, fallback): looks up
`<namespace>.<code>` in i18n and substitutes metadata. Before
substitution, `metadata.key` and `metadata.keys` (slash-joined) are
re-translated through `admin.settings.payment.field_<key>` so users see
"缺少必填项:证书序列号" instead of "缺少必填项:certSerial".
- frontend/src/i18n/locales/{zh,en}.ts
- Add payment.errors entries for every structured reason code returned by
the backend (PAYMENT_DISABLED, INVALID_AMOUNT, TOO_MANY_PENDING,
DAILY_LIMIT_EXCEEDED, NO_AVAILABLE_INSTANCE, PAYMENT_PROVIDER_MISCONFIGURED,
WXPAY_CONFIG_MISSING_KEY / INVALID_KEY_LENGTH / INVALID_KEY, NOT_FOUND,
FORBIDDEN, CONFLICT, INVALID_ORDER_TYPE, INVALID_STATUS,
BALANCE_NOT_ENOUGH, REFUND_AMOUNT_EXCEEDED, REFUND_FAILED, and more),
with placeholders for template variables.
- 13 payment-related Vue files
- Migrate catch-block error reporting from extractApiErrorMessage to
extractI18nErrorMessage(err, t, 'payment.errors', fallback).
- Remove the ad-hoc paymentErrorMap computed in SettingsView.vue, which the
new helper supersedes (it reads i18n directly via t).
- frontend/src/components/payment/providerConfig.ts
- wxpay: publicKey and publicKeyId are now required (was optional), matching
the pubkey-only verifier direction; certSerial is already required.
This PR is drop-in safe: reason-preferring extractApiErrorCode is backward
compatible with callers that pass their own i18nMap, and error codes missing
from i18n fall back to the existing message-based path.
---
.../components/payment/PaymentQRDialog.vue | 4 +-
.../components/payment/PaymentStatusPanel.vue | 4 +-
.../payment/StripePaymentInline.vue | 8 +-
.../src/components/payment/providerConfig.ts | 4 +-
frontend/src/i18n/locales/en.ts | 26 ++++++
frontend/src/i18n/locales/zh.ts | 26 ++++++
frontend/src/utils/apiError.ts | 84 ++++++++++++++++++-
frontend/src/views/admin/SettingsView.vue | 18 ++--
.../views/admin/orders/AdminOrdersView.vue | 10 +--
.../orders/AdminPaymentDashboardView.vue | 4 +-
.../admin/orders/AdminPaymentPlansView.vue | 8 +-
frontend/src/views/user/PaymentQRCodeView.vue | 4 +-
frontend/src/views/user/PaymentView.vue | 6 +-
frontend/src/views/user/StripePaymentView.vue | 6 +-
frontend/src/views/user/StripePopupView.vue | 4 +-
frontend/src/views/user/UserOrdersView.vue | 8 +-
16 files changed, 177 insertions(+), 47 deletions(-)
diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue
index db90c3b6..09d273cc 100644
--- a/frontend/src/components/payment/PaymentQRDialog.vue
+++ b/frontend/src/components/payment/PaymentQRDialog.vue
@@ -78,7 +78,7 @@ import Icon from '@/components/icons/Icon.vue'
import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import QRCode from 'qrcode'
@@ -222,7 +222,7 @@ async function handleCancel() {
cleanup()
emit('close')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 17541e59..53989dee 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -124,7 +124,7 @@ import { useI18n } from 'vue-i18n'
import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import Icon from '@/components/icons/Icon.vue'
@@ -242,7 +242,7 @@ async function handleCancel() {
cleanup()
outcome.value = 'cancelled'
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue
index 3ddff8c8..bdb0dd6b 100644
--- a/frontend/src/components/payment/StripePaymentInline.vue
+++ b/frontend/src/components/payment/StripePaymentInline.vue
@@ -67,7 +67,7 @@
import { ref, onMounted, nextTick } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRouter } from 'vue-router'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { paymentAPI } from '@/api/payment'
import { useAppStore } from '@/stores'
import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
@@ -132,7 +132,7 @@ onMounted(async () => {
selectedType.value = event.value.type
})
} catch (err: unknown) {
- initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed'))
+ initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed'))
} finally {
loading.value = false
}
@@ -186,7 +186,7 @@ async function handlePay() {
emit('success')
}
} catch (err: unknown) {
- error.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed'))
} finally {
submitting.value = false
}
@@ -199,7 +199,7 @@ async function handleCancel() {
await paymentAPI.cancelOrder(props.orderId)
emit('back')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index bf2d4177..f4f5acdc 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -99,9 +99,9 @@ export const PROVIDER_CONFIG_FIELDS: Record = {
{ key: 'mchId', label: '', sensitive: false },
{ key: 'privateKey', label: '', sensitive: true },
{ key: 'apiV3Key', label: '', sensitive: true },
+ { key: 'certSerial', label: '', sensitive: false },
{ key: 'publicKey', label: '', sensitive: true },
- { key: 'publicKeyId', label: '', sensitive: false, optional: true },
- { key: 'certSerial', label: '', sensitive: false, optional: true },
+ { key: 'publicKeyId', label: '', sensitive: false },
],
stripe: [
{ key: 'secretKey', label: '', sensitive: true },
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index c0a17d96..8213cb0f 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -5432,7 +5432,33 @@ export default {
errors: {
tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.',
cancelRateLimited: 'Too many cancellations. Please try again later.',
+ // Structured error codes (reason strings from backend ApplicationError)
+ PAYMENT_DISABLED: 'Payment system is disabled.',
+ USER_INACTIVE: 'Your account is disabled.',
+ BALANCE_PAYMENT_DISABLED: 'Balance recharge has been disabled.',
+ INVALID_AMOUNT: 'Invalid amount.',
+ INVALID_INPUT: 'Invalid request.',
+ PLAN_NOT_AVAILABLE: 'Plan not found or no longer available.',
+ GROUP_NOT_FOUND: 'Subscription group is no longer available.',
+ GROUP_TYPE_MISMATCH: 'Group is not a subscription type.',
+ TOO_MANY_PENDING: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.',
+ DAILY_LIMIT_EXCEEDED: 'Daily recharge limit reached. Remaining: {remaining}.',
+ PAYMENT_GATEWAY_ERROR: 'Payment method is unavailable.',
+ NO_AVAILABLE_INSTANCE: 'No payment channel available right now.',
+ PAYMENT_PROVIDER_MISCONFIGURED: 'Payment provider misconfigured. Please contact an administrator.',
+ WXPAY_CONFIG_MISSING_KEY: 'WeChat Pay config missing required key: {key}.',
+ WXPAY_CONFIG_INVALID_KEY_LENGTH: 'WeChat Pay {key} length is invalid (expected {expected} bytes, got {actual}).',
+ WXPAY_CONFIG_INVALID_KEY: 'WeChat Pay {key} is malformed. Make sure you copied the full PEM content.',
PENDING_ORDERS: 'This provider has pending orders. Please wait for them to complete before making changes.',
+ CANCEL_RATE_LIMITED: 'Too many cancellations. Please try again later.',
+ NOT_FOUND: 'Order not found.',
+ FORBIDDEN: 'No permission for this order.',
+ CONFLICT: 'Order status has changed. Please refresh.',
+ INVALID_ORDER_TYPE: 'Only balance orders can request a refund.',
+ INVALID_STATUS: 'The current order status does not allow this operation.',
+ BALANCE_NOT_ENOUGH: 'Refund amount exceeds balance.',
+ REFUND_AMOUNT_EXCEEDED: 'Refund amount exceeds the recharge amount.',
+ REFUND_FAILED: 'Refund failed.',
},
stripePay: 'Pay Now',
stripeSuccessProcessing: 'Payment successful, processing your order...',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index ba9edd7f..5f936965 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -5620,7 +5620,33 @@ export default {
errors: {
tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单',
cancelRateLimited: '取消订单过于频繁,请稍后再试',
+ // Structured error codes (reason strings from backend ApplicationError)
+ PAYMENT_DISABLED: '支付系统已关闭',
+ USER_INACTIVE: '账号已被禁用',
+ BALANCE_PAYMENT_DISABLED: '余额充值功能已关闭',
+ INVALID_AMOUNT: '金额无效',
+ INVALID_INPUT: '参数有误',
+ PLAN_NOT_AVAILABLE: '套餐不存在或已下架',
+ GROUP_NOT_FOUND: '订阅分组不可用',
+ GROUP_TYPE_MISMATCH: '分组类型不是订阅类型',
+ TOO_MANY_PENDING: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单',
+ DAILY_LIMIT_EXCEEDED: '今日充值已达上限,剩余额度 {remaining}',
+ PAYMENT_GATEWAY_ERROR: '支付方式不可用',
+ NO_AVAILABLE_INSTANCE: '暂无可用的支付通道',
+ PAYMENT_PROVIDER_MISCONFIGURED: '支付通道配置错误,请联系管理员',
+ WXPAY_CONFIG_MISSING_KEY: '微信支付配置缺少必填项:{key}',
+ WXPAY_CONFIG_INVALID_KEY_LENGTH: '微信支付 {key} 长度错误,应为 {expected} 字节(实际 {actual})',
+ WXPAY_CONFIG_INVALID_KEY: '微信支付 {key} 格式错误,请确认复制了完整的 PEM 内容',
PENDING_ORDERS: '该服务商有未完成的订单,请等待订单完成后再操作',
+ CANCEL_RATE_LIMITED: '取消订单过于频繁,请稍后再试',
+ NOT_FOUND: '订单不存在',
+ FORBIDDEN: '无权限操作此订单',
+ CONFLICT: '订单状态已变更,请刷新',
+ INVALID_ORDER_TYPE: '仅余额订单可申请退款',
+ INVALID_STATUS: '当前订单状态不允许此操作',
+ BALANCE_NOT_ENOUGH: '退款金额超过余额',
+ REFUND_AMOUNT_EXCEEDED: '退款金额超过充值金额',
+ REFUND_FAILED: '退款失败',
},
stripePay: '立即支付',
stripeSuccessProcessing: '支付成功,正在处理订单...',
diff --git a/frontend/src/utils/apiError.ts b/frontend/src/utils/apiError.ts
index e1fe0c30..07a17aca 100644
--- a/frontend/src/utils/apiError.ts
+++ b/frontend/src/utils/apiError.ts
@@ -23,14 +23,96 @@ interface ApiErrorLike {
/**
* Extract the error code from an API error object.
+ *
+ * Prefers the string `reason` (e.g. "PAYMENT_PROVIDER_MISCONFIGURED") over the
+ * numeric HTTP `code`, because reason is granular enough to drive i18n lookup
+ * while HTTP code is not.
*/
export function extractApiErrorCode(err: unknown): string | undefined {
if (!err || typeof err !== 'object') return undefined
const e = err as ApiErrorLike
- const code = e.code ?? e.reason ?? e.response?.data?.code
+ const code = e.reason ?? e.code ?? e.response?.data?.code
return code != null ? String(code) : undefined
}
+/**
+ * Extract metadata (interpolation params) from an API error object.
+ * Backend errors carry `metadata` with template variables that fill i18n placeholders.
+ */
+export function extractApiErrorMetadata(err: unknown): Record<string, unknown> | undefined {
+ if (!err || typeof err !== 'object') return undefined
+ const e = err as ApiErrorLike
+ return e.metadata
+}
+
+type TranslateFn = (key: string, params?: Record<string, unknown>) => string
+type TranslateWithExistsFn = TranslateFn & { te?: (key: string) => boolean }
+
+/**
+ * Translate a value via i18n if a matching key exists, otherwise return the original.
+ * Example: "certSerial" → t('admin.settings.payment.field_certSerial') → "证书序列号".
+ */
+function tryTranslate(t: TranslateFn, key: string, fallback: string): string {
+ const translated = t(key)
+ if (translated === key) return fallback
+ const te = (t as TranslateWithExistsFn).te
+ if (te && !te(key)) return fallback
+ return translated
+}
+
+/**
+ * Replace raw config field names in metadata (e.g. "certSerial") with their
+ * localized UI labels (e.g. "证书序列号"), using the provider-config field i18n namespace.
+ * Handles both single `key` and `/`-joined `keys` patterns used by wxpay errors.
+ */
+function localizeMetadata(metadata: Record<string, unknown>, t: TranslateFn): Record<string, unknown> {
+ const out: Record = { ...metadata }
+ if (typeof out.key === 'string') {
+ out.key = tryTranslate(t, `admin.settings.payment.field_${out.key}`, out.key)
+ }
+ if (typeof out.keys === 'string') {
+ out.keys = out.keys
+ .split('/')
+ .map(k => tryTranslate(t, `admin.settings.payment.field_${k}`, k))
+ .join(' / ')
+ }
+ return out
+}
+
+/**
+ * Extract a localized error message from an API error by looking up
+ * `{namespace}.{code}` in i18n and substituting metadata as placeholders.
+ *
+ * Config-field names in metadata (`key` / `keys`) are automatically translated
+ * to their UI labels before substitution, so error messages read like
+ * "缺少必填项:证书序列号" instead of "缺少必填项:certSerial".
+ *
+ * @param err - The caught error
+ * @param t - Vue i18n translate function
+ * @param namespace - i18n key prefix, e.g. "payment.errors"
+ * @param fallback - Fallback key or plain string if no localized mapping exists
+ */
+export function extractI18nErrorMessage(
+ err: unknown,
+ t: TranslateFn,
+ namespace: string,
+ fallback: string,
+): string {
+ const code = extractApiErrorCode(err)
+ if (code) {
+ const key = `${namespace}.${code}`
+ const rawMetadata = extractApiErrorMetadata(err) ?? {}
+ const metadata = localizeMetadata(rawMetadata, t)
+ const translated = t(key, metadata)
+ // Vue i18n returns the key itself when missing; detect that and fall back.
+ if (translated !== key) return translated
+ // If the framework exposes `te`, use it to double-check.
+ const te = (t as TranslateWithExistsFn).te
+ if (te && te(key)) return translated
+ }
+ return extractApiErrorMessage(err, fallback)
+}
+
/**
* Extract a displayable error message from an API error.
*
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index ee6a4c6d..27cb1f0c 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2850,7 +2850,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
import ImageUpload from '@/components/common/ImageUpload.vue'
import BackupSettings from '@/views/admin/BackupView.vue'
import { useClipboard } from '@/composables/useClipboard'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractApiErrorMessage, extractI18nErrorMessage } from '@/utils/apiError'
import { useAppStore } from '@/stores'
import { useAdminSettingsStore } from '@/stores/adminSettings'
import {
@@ -4085,14 +4085,10 @@ const cancelRateLimitModeOptions = computed(() => [
{ value: 'fixed', label: t('admin.settings.payment.cancelRateLimitWindowModeFixed') },
])
-const paymentErrorMap = computed(() => ({
- PENDING_ORDERS: t('payment.errors.PENDING_ORDERS'),
-}))
-
async function loadProviders() {
providersLoading.value = true
try { const res = await adminAPI.payment.getProviders(); providers.value = res.data || [] }
- catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
finally { providersLoading.value = false }
}
@@ -4122,7 +4118,7 @@ async function handleSaveProvider(payload: Partial) {
// Auto-save settings so provider changes take effect immediately
await saveSettings()
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
providerSaving.value = false
}
@@ -4148,7 +4144,7 @@ async function handleToggleField(provider: ProviderInstance, field: 'enabled' |
} else {
provider.allow_user_refund = newValue
}
- } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) }
+ } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
}
async function handleToggleType(provider: ProviderInstance, type: string) {
@@ -4158,7 +4154,7 @@ async function handleToggleType(provider: ProviderInstance, type: string) {
try {
await adminAPI.payment.updateProvider(provider.id, { supported_types: updated } as any)
provider.supported_types = updated
- } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) }
+ } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
}
function confirmDeleteProvider(provider: ProviderInstance) {
@@ -4177,7 +4173,7 @@ async function handleReorderProviders(updates: { id: number; sort_order: number
if (p) p.sort_order = u.sort_order
}
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
loadProviders()
}
}
@@ -4189,7 +4185,7 @@ async function handleDeleteProvider() {
appStore.showSuccess(t('common.deleted'))
showDeleteProviderDialog.value = false
loadProviders()
- } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) }
+ } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
}
onMounted(() => {
diff --git a/frontend/src/views/admin/orders/AdminOrdersView.vue b/frontend/src/views/admin/orders/AdminOrdersView.vue
index 027c8e5e..dd9fa7e6 100644
--- a/frontend/src/views/admin/orders/AdminOrdersView.vue
+++ b/frontend/src/views/admin/orders/AdminOrdersView.vue
@@ -116,7 +116,7 @@ import { ref, reactive, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminPaymentAPI } from '@/api/admin/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { formatOrderDateTime } from '@/components/payment/orderUtils'
import type { PaymentOrder } from '@/types/payment'
import AppLayout from '@/components/layout/AppLayout.vue'
@@ -167,7 +167,7 @@ async function loadOrders() {
orders.value = res.data.items || []
orderPagination.total = res.data.total || 0
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally { ordersLoading.value = false }
}
@@ -214,12 +214,12 @@ async function showOrderDetail(order: PaymentOrder) {
async function handleCancelOrder(order: PaymentOrder) {
try { await adminPaymentAPI.cancelOrder(order.id); appStore.showSuccess(t('payment.admin.orderCancelled')); loadOrders() }
- catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
}
async function handleRetryOrder(order: PaymentOrder) {
try { await adminPaymentAPI.retryRecharge(order.id); appStore.showSuccess(t('payment.admin.retrySuccess')); loadOrders() }
- catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
}
function openRefundDialog(order: PaymentOrder) { selectedOrder.value = order; showRefundDialog.value = true }
@@ -230,7 +230,7 @@ async function handleRefund(data: { amount: number; reason: string; deduct_balan
try {
await adminPaymentAPI.refundOrder(selectedOrder.value.id, { amount: data.amount, reason: data.reason, deduct_balance: data.deduct_balance, force: data.force })
appStore.showSuccess(t('payment.admin.refundSuccess')); showRefundDialog.value = false; loadOrders()
- } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
finally { refundSubmitting.value = false }
}
diff --git a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue
index 06bc9218..5a80db44 100644
--- a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue
+++ b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue
@@ -72,7 +72,7 @@ import { ref, watch, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminPaymentAPI } from '@/api/admin/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import type { DashboardStats } from '@/types/payment'
import AppLayout from '@/components/layout/AppLayout.vue'
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
@@ -110,7 +110,7 @@ async function loadDashboard() {
const res = await adminPaymentAPI.getDashboard(days.value)
stats.value = res.data
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
loading.value = false
}
diff --git a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue
index 876b2aa1..c2fc26fe 100644
--- a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue
+++ b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue
@@ -78,7 +78,7 @@ import { ref, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import { adminPaymentAPI } from '@/api/admin/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import adminAPI from '@/api/admin'
import type { SubscriptionPlan } from '@/types/payment'
import type { AdminGroup } from '@/types'
@@ -150,7 +150,7 @@ async function loadPlans() {
: (p.features || []),
}))
}
- catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
finally { plansLoading.value = false }
}
@@ -166,7 +166,7 @@ async function toggleForSale(plan: SubscriptionPlan) {
await adminPaymentAPI.updatePlan(plan.id, { for_sale: !plan.for_sale })
plan.for_sale = !plan.for_sale
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
}
}
@@ -174,7 +174,7 @@ function confirmDeletePlan(plan: SubscriptionPlan) { deletingPlanId.value = plan
async function handleDeletePlan() {
if (!deletingPlanId.value) return
try { await adminPaymentAPI.deletePlan(deletingPlanId.value); appStore.showSuccess(t('common.deleted')); showDeletePlanDialog.value = false; loadPlans() }
- catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
}
// ==================== Lifecycle ====================
diff --git a/frontend/src/views/user/PaymentQRCodeView.vue b/frontend/src/views/user/PaymentQRCodeView.vue
index 0965947a..f844858d 100644
--- a/frontend/src/views/user/PaymentQRCodeView.vue
+++ b/frontend/src/views/user/PaymentQRCodeView.vue
@@ -39,7 +39,7 @@ import { useRoute, useRouter } from 'vue-router'
import AppLayout from '@/components/layout/AppLayout.vue'
import { usePaymentStore } from '@/stores/payment'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { useAppStore } from '@/stores'
import QRCode from 'qrcode'
import alipayIcon from '@/assets/icons/alipay.svg'
@@ -167,7 +167,7 @@ async function handleCancel() {
cleanup()
router.push('/purchase')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index 3f1401b3..e2885c80 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -271,7 +271,7 @@ import { usePaymentStore } from '@/stores/payment'
import { useSubscriptionStore } from '@/stores/subscriptions'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { isMobileDevice } from '@/utils/device'
import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/payment'
import AppLayout from '@/components/layout/AppLayout.vue'
@@ -610,7 +610,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
} else if (apiErr.reason === 'CANCEL_RATE_LIMITED') {
errorMessage.value = t('payment.errors.cancelRateLimited')
} else {
- errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ errorMessage.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed'))
}
appStore.showError(errorMessage.value)
} finally {
@@ -648,7 +648,7 @@ onMounted(async () => {
}
}
}
- } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
+ } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) }
finally { loading.value = false }
// Fetch active subscriptions (uses cache, non-blocking)
subscriptionStore.fetchActiveSubscriptions().catch(() => {})
diff --git a/frontend/src/views/user/StripePaymentView.vue b/frontend/src/views/user/StripePaymentView.vue
index 20a4a408..3f73d4d5 100644
--- a/frontend/src/views/user/StripePaymentView.vue
+++ b/frontend/src/views/user/StripePaymentView.vue
@@ -99,7 +99,7 @@ import { useI18n } from 'vue-i18n'
import { useRoute, useRouter } from 'vue-router'
import { usePaymentStore } from '@/stores/payment'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { isMobileDevice } from '@/utils/device'
import type { PaymentOrder } from '@/types/payment'
import type { Stripe, StripeElements } from '@stripe/stripe-js'
@@ -167,7 +167,7 @@ onMounted(async () => {
mountPaymentElement(stripe, clientSecret)
}
} catch (err: unknown) {
- initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed'))
+ initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed'))
} finally {
loading.value = false
}
@@ -248,7 +248,7 @@ async function handleGenericPay() {
scheduleClose()
}
} catch (err: unknown) {
- stripeError.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ stripeError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed'))
} finally {
stripeSubmitting.value = false
}
diff --git a/frontend/src/views/user/StripePopupView.vue b/frontend/src/views/user/StripePopupView.vue
index 2704c62d..688ad644 100644
--- a/frontend/src/views/user/StripePopupView.vue
+++ b/frontend/src/views/user/StripePopupView.vue
@@ -56,7 +56,7 @@
import { computed, ref, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRoute } from 'vue-router'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { isMobileDevice } from '@/utils/device'
interface StripeWithWechatPay {
@@ -143,7 +143,7 @@ async function initStripe(clientSecret: string, publishableKey: string) {
}
}
} catch (err: unknown) {
- error.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed'))
+ error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed'))
}
}
diff --git a/frontend/src/views/user/UserOrdersView.vue b/frontend/src/views/user/UserOrdersView.vue
index ea888eb7..c3ed80eb 100644
--- a/frontend/src/views/user/UserOrdersView.vue
+++ b/frontend/src/views/user/UserOrdersView.vue
@@ -86,7 +86,7 @@ import { useI18n } from 'vue-i18n'
import { useRouter } from 'vue-router'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import type { PaymentOrder } from '@/types/payment'
import AppLayout from '@/components/layout/AppLayout.vue'
import Pagination from '@/components/common/Pagination.vue'
@@ -128,7 +128,7 @@ async function fetchOrders() {
orders.value = res.data.items || []
pagination.total = res.data.total || 0
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
loading.value = false
}
@@ -148,7 +148,7 @@ async function confirmCancel() {
cancelTargetId.value = null
await fetchOrders()
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
actionLoading.value = false
}
@@ -166,7 +166,7 @@ async function confirmRefund() {
refundReason.value = ''
await fetchOrders()
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
actionLoading.value = false
}
--
GitLab
From 97c9b992cbf8b658b6ef27c27fd0041893b74317 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:27:15 +0800
Subject: [PATCH 048/261] fix: require wechat unionid for oauth identity
---
backend/internal/handler/auth_wechat_oauth.go | 6 +-
.../handler/auth_wechat_oauth_test.go | 107 +++++++++++-------
2 files changed, 66 insertions(+), 47 deletions(-)
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 816f60fd..f0755f1f 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -193,11 +193,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
- providerSubject := firstNonEmpty(unionid, openid)
- if providerSubject == "" {
- redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "")
+ if unionid == "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
return
}
+ providerSubject := unionid
username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
email := wechatSyntheticEmail(providerSubject)
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index 0d1df1b6..1ff80e1b 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -6,7 +6,6 @@ import (
"bytes"
"context"
"database/sql"
- "encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
@@ -122,6 +121,59 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
}
+func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Contains(t, recorder.Header().Get("Location"), "#error=provider_error")
+ require.Contains(t, recorder.Header().Get("Location"), "error_message=wechat_missing_unionid")
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+
+ count, err := client.PendingAuthSession.Query().Count(context.Background())
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
testCases := []struct {
name string
@@ -542,12 +594,7 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl
userRepo := &oauthPendingFlowUserRepo{client: client}
redeemRepo := repository.NewRedeemCodeRepository(client)
- settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
- values: map[string]string{
- service.SettingKeyRegistrationEnabled: "true",
- service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
- },
- }, &config.Config{
+ cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
@@ -558,25 +605,20 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl
UserBalance: 0,
UserConcurrency: 1,
},
- })
+ }
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ },
+ }, cfg)
authSvc := service.NewAuthService(
client,
userRepo,
redeemRepo,
&wechatOAuthRefreshTokenCacheStub{},
- &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret",
- ExpireHour: 1,
- AccessTokenExpireMinutes: 60,
- RefreshTokenExpireDays: 7,
- },
- Default: config.DefaultConfig{
- UserBalance: 0,
- UserConcurrency: 1,
- },
- },
+ cfg,
settingSvc,
nil,
nil,
@@ -588,33 +630,10 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl
return &AuthHandler{
authService: authSvc,
settingSvc: settingSvc,
+ cfg: cfg,
}, client
}
-func encodedCookie(name, value string) *http.Cookie {
- return &http.Cookie{
- Name: name,
- Value: encodeCookieValue(value),
- Path: "/",
- }
-}
-
-func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
- for _, cookie := range cookies {
- if cookie.Name == name {
- return cookie
- }
- }
- return nil
-}
-
-func decodeCookieValueForTest(t *testing.T, value string) string {
- t.Helper()
- raw, err := base64.RawURLEncoding.DecodeString(value)
- require.NoError(t, err)
- return string(raw)
-}
-
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
t.Helper()
--
GitLab
From 7c7924e9fa6125f47e2b080496466ab22cec40ad Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:31:19 +0800
Subject: [PATCH 049/261] fix: guard payment fulfillment provider mismatch
---
.../internal/service/payment_fulfillment.go | 25 ++++++++
.../service/payment_fulfillment_test.go | 64 +++++++++++++++++++
2 files changed, 89 insertions(+)
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 44818b37..519455f0 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -41,6 +41,19 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
slog.Error("order not found", "orderID", oid)
return nil
}
+ instanceProviderKey := ""
+ if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
+ instanceProviderKey = inst.ProviderKey
+ }
+ expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey)
+ if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
+ "expectedProvider": expectedProviderKey,
+ "actualProvider": pk,
+ "tradeNo": tradeNo,
+ })
+ return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
+ }
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
@@ -56,6 +69,18 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
+func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string {
+ if key := strings.TrimSpace(instanceProviderKey); key != "" {
+ return key
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(orderPaymentType)
+}
+
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
previousStatus := o.Status
now := time.Now()
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 625b0d9f..4cc00301 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -3,12 +3,37 @@
package service
import (
+ "context"
"errors"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
+type paymentFulfillmentTestProvider struct {
+ key string
+ supportedTypes []payment.PaymentType
+}
+
+func (p paymentFulfillmentTestProvider) Name() string { return p.key }
+func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key }
+func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType {
+ return p.supportedTypes
+}
+func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
// ---------------------------------------------------------------------------
// resolveRedeemAction — pure idempotency decision logic
// ---------------------------------------------------------------------------
@@ -161,3 +186,42 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
assert.True(t, unusedCode.CanUse())
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
}
+
+func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay),
+ )
+}
+
+func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeEasyPay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t,
+ payment.TypeWxpay,
+ expectedNotificationProviderKey(nil, payment.TypeWxpay, ""),
+ )
+}
--
GitLab
From e3f69e02464fa79153aa16db7a457359f6a9a8a3 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:42:01 +0800
Subject: [PATCH 050/261] fix: tighten webhook provider resolution
---
backend/internal/service/payment_service.go | 20 ---
.../service/payment_webhook_provider.go | 86 +++++++++++
.../service/payment_webhook_provider_test.go | 141 ++++++++++++++++++
3 files changed, 227 insertions(+), 20 deletions(-)
create mode 100644 backend/internal/service/payment_webhook_provider.go
create mode 100644 backend/internal/service/payment_webhook_provider_test.go
diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go
index e897741a..d3175ba6 100644
--- a/backend/internal/service/payment_service.go
+++ b/backend/internal/service/payment_service.go
@@ -9,7 +9,6 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
- "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
@@ -225,25 +224,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) {
}
}
-// GetWebhookProvider returns the provider instance that should verify a webhook.
-// It extracts out_trade_no from the raw body, looks up the order to find the
-// original provider instance, and creates a provider with that instance's credentials.
-// Falls back to the registry provider when the order cannot be found.
-func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
- if outTradeNo != "" {
- order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
- if err == nil {
- p, pErr := s.getOrderProvider(ctx, order)
- if pErr == nil {
- return p, nil
- }
- slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
- }
- }
- s.EnsureProviders(ctx)
- return s.registry.GetProviderByKey(providerKey)
-}
-
// --- Helpers ---
func psIsRefundStatus(s string) bool {
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
new file mode 100644
index 00000000..a877db2b
--- /dev/null
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -0,0 +1,86 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "strconv"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/payment/provider"
+)
+
+// GetWebhookProvider returns the provider instance that should verify a webhook.
+// It resolves the original provider instance from the order whenever possible and
+// only falls back to a registry provider for legacy/single-instance scenarios.
+func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
+ if outTradeNo != "" {
+ order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
+ if err == nil {
+ if psHasPinnedProviderInstance(order) {
+ return s.getPinnedOrderProvider(ctx, order)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
+ }
+ s.EnsureProviders(ctx)
+ return s.registry.GetProviderByKey(providerKey)
+ }
+ }
+
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
+ }
+
+ s.EnsureProviders(ctx)
+ return s.registry.GetProviderByKey(providerKey)
+}
+
+func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
+ inst, err := s.getOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("order %d provider instance is missing", o.ID)
+ }
+
+ instID := strconv.FormatInt(int64(inst.ID), 10)
+ cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance config: %w", err)
+ }
+
+ prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
+ if err != nil {
+ return nil, fmt.Errorf("create pinned provider: %w", err)
+ }
+ return prov, nil
+}
+
+func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool {
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" || s == nil || s.entClient == nil {
+ return false
+ }
+
+ count, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ paymentproviderinstance.EnabledEQ(true),
+ ).
+ Count(ctx)
+ if err != nil {
+ slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err)
+ return false
+ }
+ return count <= 1
+}
+
+func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
+ return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
+}
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
new file mode 100644
index 00000000..85c296de
--- /dev/null
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -0,0 +1,141 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+type webhookProviderTestDouble struct {
+ key string
+ types []payment.PaymentType
+}
+
+func (p webhookProviderTestDouble) Name() string { return p.key }
+func (p webhookProviderTestDouble) ProviderKey() string { return p.key }
+func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types }
+func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-b").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "ambiguous")
+}
+
+func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-a").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ registry.Register(webhookProviderTestDouble{
+ key: payment.TypeStripe,
+ types: []payment.PaymentType{payment.TypeStripe},
+ })
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "")
+ require.NoError(t, err)
+ require.Equal(t, payment.TypeStripe, prov.ProviderKey())
+}
+
+func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ pinnedInstanceID := "999"
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("TEST-RECHARGE").
+ SetOutTradeNo("sub2_test_pinned_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(pinnedInstanceID).
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ registry.Register(webhookProviderTestDouble{
+ key: payment.TypeWxpay,
+ types: []payment.PaymentType{payment.TypeWxpay},
+ })
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider instance")
+}
--
GitLab
From c0b24aefba926c61d749af11cd01cd6f96bb65fe Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:47:14 +0800
Subject: [PATCH 051/261] feat: snapshot payment provider keys on orders
---
backend/ent/migrate/schema.go | 15 ++--
backend/ent/mutation.go | 75 ++++++++++++++++-
backend/ent/paymentorder.go | 16 +++-
backend/ent/paymentorder/paymentorder.go | 10 +++
backend/ent/paymentorder/where.go | 80 ++++++++++++++++++
backend/ent/paymentorder_create.go | 83 +++++++++++++++++++
backend/ent/paymentorder_update.go | 62 ++++++++++++++
backend/ent/runtime/runtime.go | 20 +++--
backend/ent/schema/payment_order.go | 4 +
.../internal/service/payment_fulfillment.go | 7 +-
.../service/payment_fulfillment_test.go | 21 ++++-
backend/internal/service/payment_order.go | 8 +-
.../service/payment_order_lifecycle.go | 13 ++-
...dd_payment_order_provider_key_snapshot.sql | 10 +++
14 files changed, 400 insertions(+), 24 deletions(-)
create mode 100644 backend/migrations/112_add_payment_order_provider_key_snapshot.sql
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index bf41e73b..230ea060 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -654,6 +654,7 @@ var (
{Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "subscription_days", Type: field.TypeInt, Nullable: true},
{Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64},
+ {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30},
{Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"},
{Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
{Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
@@ -682,7 +683,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "payment_orders_users_payment_orders",
- Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ Columns: []*schema.Column{PaymentOrdersColumns[38]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -696,32 +697,32 @@ var (
{
Name: "paymentorder_user_id",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ Columns: []*schema.Column{PaymentOrdersColumns[38]},
},
{
Name: "paymentorder_status",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[19]},
+ Columns: []*schema.Column{PaymentOrdersColumns[20]},
},
{
Name: "paymentorder_expires_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[27]},
+ Columns: []*schema.Column{PaymentOrdersColumns[28]},
},
{
Name: "paymentorder_created_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[35]},
+ Columns: []*schema.Column{PaymentOrdersColumns[36]},
},
{
Name: "paymentorder_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[29]},
},
{
Name: "paymentorder_payment_type_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[29]},
},
{
Name: "paymentorder_order_type",
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 12905c9a..5227015c 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -15385,6 +15385,7 @@ type PaymentOrderMutation struct {
subscription_days *int
addsubscription_days *int
provider_instance_id *string
+ provider_key *string
status *string
refund_amount *float64
addrefund_amount *float64
@@ -16421,6 +16422,55 @@ func (m *PaymentOrderMutation) ResetProviderInstanceID() {
delete(m.clearedFields, paymentorder.FieldProviderInstanceID)
}
+// SetProviderKey sets the "provider_key" field.
+func (m *PaymentOrderMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (m *PaymentOrderMutation) ClearProviderKey() {
+ m.provider_key = nil
+ m.clearedFields[paymentorder.FieldProviderKey] = struct{}{}
+}
+
+// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderKeyCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderKey]
+ return ok
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PaymentOrderMutation) ResetProviderKey() {
+ m.provider_key = nil
+ delete(m.clearedFields, paymentorder.FieldProviderKey)
+}
+
// SetStatus sets the "status" field.
func (m *PaymentOrderMutation) SetStatus(s string) {
m.status = &s
@@ -17280,7 +17330,7 @@ func (m *PaymentOrderMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentOrderMutation) Fields() []string {
- fields := make([]string, 0, 37)
+ fields := make([]string, 0, 38)
if m.user != nil {
fields = append(fields, paymentorder.FieldUserID)
}
@@ -17338,6 +17388,9 @@ func (m *PaymentOrderMutation) Fields() []string {
if m.provider_instance_id != nil {
fields = append(fields, paymentorder.FieldProviderInstanceID)
}
+ if m.provider_key != nil {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
if m.status != nil {
fields = append(fields, paymentorder.FieldStatus)
}
@@ -17438,6 +17491,8 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
return m.SubscriptionDays()
case paymentorder.FieldProviderInstanceID:
return m.ProviderInstanceID()
+ case paymentorder.FieldProviderKey:
+ return m.ProviderKey()
case paymentorder.FieldStatus:
return m.Status()
case paymentorder.FieldRefundAmount:
@@ -17521,6 +17576,8 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldSubscriptionDays(ctx)
case paymentorder.FieldProviderInstanceID:
return m.OldProviderInstanceID(ctx)
+ case paymentorder.FieldProviderKey:
+ return m.OldProviderKey(ctx)
case paymentorder.FieldStatus:
return m.OldStatus(ctx)
case paymentorder.FieldRefundAmount:
@@ -17699,6 +17756,13 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
}
m.SetProviderInstanceID(v)
return nil
+ case paymentorder.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
case paymentorder.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -17966,6 +18030,9 @@ func (m *PaymentOrderMutation) ClearedFields() []string {
if m.FieldCleared(paymentorder.FieldProviderInstanceID) {
fields = append(fields, paymentorder.FieldProviderInstanceID)
}
+ if m.FieldCleared(paymentorder.FieldProviderKey) {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
if m.FieldCleared(paymentorder.FieldRefundReason) {
fields = append(fields, paymentorder.FieldRefundReason)
}
@@ -18034,6 +18101,9 @@ func (m *PaymentOrderMutation) ClearField(name string) error {
case paymentorder.FieldProviderInstanceID:
m.ClearProviderInstanceID()
return nil
+ case paymentorder.FieldProviderKey:
+ m.ClearProviderKey()
+ return nil
case paymentorder.FieldRefundReason:
m.ClearRefundReason()
return nil
@@ -18129,6 +18199,9 @@ func (m *PaymentOrderMutation) ResetField(name string) error {
case paymentorder.FieldProviderInstanceID:
m.ResetProviderInstanceID()
return nil
+ case paymentorder.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
case paymentorder.FieldStatus:
m.ResetStatus()
return nil
diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go
index 6ea3e709..a58823ee 100644
--- a/backend/ent/paymentorder.go
+++ b/backend/ent/paymentorder.go
@@ -56,6 +56,8 @@ type PaymentOrder struct {
SubscriptionDays *int `json:"subscription_days,omitempty"`
// ProviderInstanceID holds the value of the "provider_instance_id" field.
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey *string `json:"provider_key,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// RefundAmount holds the value of the "refund_amount" field.
@@ -129,7 +131,7 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays:
values[i] = new(sql.NullInt64)
- case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
+ case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
values[i] = new(sql.NullString)
case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@@ -276,6 +278,13 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error {
_m.ProviderInstanceID = new(string)
*_m.ProviderInstanceID = value.String
}
+ case paymentorder.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = new(string)
+ *_m.ProviderKey = value.String
+ }
case paymentorder.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -508,6 +517,11 @@ func (_m *PaymentOrder) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ if v := _m.ProviderKey; v != nil {
+ builder.WriteString("provider_key=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go
index 4467b2b6..af9b1422 100644
--- a/backend/ent/paymentorder/paymentorder.go
+++ b/backend/ent/paymentorder/paymentorder.go
@@ -52,6 +52,8 @@ const (
FieldSubscriptionDays = "subscription_days"
// FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database.
FieldProviderInstanceID = "provider_instance_id"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldRefundAmount holds the string denoting the refund_amount field in the database.
@@ -123,6 +125,7 @@ var Columns = []string{
FieldSubscriptionGroupID,
FieldSubscriptionDays,
FieldProviderInstanceID,
+ FieldProviderKey,
FieldStatus,
FieldRefundAmount,
FieldRefundReason,
@@ -176,6 +179,8 @@ var (
OrderTypeValidator func(string) error
// ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
ProviderInstanceIDValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -301,6 +306,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc()
}
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go
index 78520fac..0f6b74a0 100644
--- a/backend/ent/paymentorder/where.go
+++ b/backend/ent/paymentorder/where.go
@@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v))
}
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
@@ -1360,6 +1365,81 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v))
}
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field.
+func ProviderKeyIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey))
+}
+
+// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field.
+func ProviderKeyNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go
index 03098339..497ba52c 100644
--- a/backend/ent/paymentorder_create.go
+++ b/backend/ent/paymentorder_create.go
@@ -225,6 +225,20 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO
return _c
}
+// SetProviderKey sets the "provider_key" field.
+func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetProviderKey(*v)
+ }
+ return _c
+}
+
// SetStatus sets the "status" field.
func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate {
_c.mutation.SetStatus(v)
@@ -602,6 +616,11 @@ func (_c *PaymentOrderCreate) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)}
}
@@ -748,6 +767,10 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec)
_spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value)
_node.ProviderInstanceID = &value
}
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = &value
+ }
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -1201,6 +1224,24 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert {
return u
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderKey)
+ return u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderKey)
+ return u
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert {
u.Set(paymentorder.FieldStatus, v)
@@ -1880,6 +1921,27 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne
})
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne {
return u.Update(func(s *PaymentOrderUpsert) {
@@ -2770,6 +2832,27 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu
})
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk {
return u.Update(func(s *PaymentOrderUpsert) {
diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go
index 5978fc29..9a901415 100644
--- a/backend/ent/paymentorder_update.go
+++ b/backend/ent/paymentorder_update.go
@@ -385,6 +385,26 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate {
return _u
}
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate {
_u.mutation.SetStatus(v)
@@ -776,6 +796,11 @@ func (_u *PaymentOrderUpdate) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Status(); ok {
if err := paymentorder.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
@@ -910,6 +935,12 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error
if _u.mutation.ProviderInstanceIDCleared() {
_spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
@@ -1399,6 +1430,26 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn
return _u
}
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne {
_u.mutation.SetStatus(v)
@@ -1803,6 +1854,11 @@ func (_u *PaymentOrderUpdateOne) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Status(); ok {
if err := paymentorder.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
@@ -1954,6 +2010,12 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd
if _u.mutation.ProviderInstanceIDCleared() {
_spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 268e9ddb..b7118ac9 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -723,38 +723,42 @@ func init() {
paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor()
// paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error)
+ // paymentorderDescProviderKey is the schema descriptor for provider_key field.
+ paymentorderDescProviderKey := paymentorderFields[19].Descriptor()
+ // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error)
// paymentorderDescStatus is the schema descriptor for status field.
- paymentorderDescStatus := paymentorderFields[19].Descriptor()
+ paymentorderDescStatus := paymentorderFields[20].Descriptor()
// paymentorder.DefaultStatus holds the default value on creation for the status field.
paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string)
// paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save.
paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error)
// paymentorderDescRefundAmount is the schema descriptor for refund_amount field.
- paymentorderDescRefundAmount := paymentorderFields[20].Descriptor()
+ paymentorderDescRefundAmount := paymentorderFields[21].Descriptor()
// paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field.
paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64)
// paymentorderDescForceRefund is the schema descriptor for force_refund field.
- paymentorderDescForceRefund := paymentorderFields[23].Descriptor()
+ paymentorderDescForceRefund := paymentorderFields[24].Descriptor()
// paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field.
paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool)
// paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field.
- paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor()
+ paymentorderDescRefundRequestedBy := paymentorderFields[27].Descriptor()
// paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error)
// paymentorderDescClientIP is the schema descriptor for client_ip field.
- paymentorderDescClientIP := paymentorderFields[32].Descriptor()
+ paymentorderDescClientIP := paymentorderFields[33].Descriptor()
// paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error)
// paymentorderDescSrcHost is the schema descriptor for src_host field.
- paymentorderDescSrcHost := paymentorderFields[33].Descriptor()
+ paymentorderDescSrcHost := paymentorderFields[34].Descriptor()
// paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error)
// paymentorderDescCreatedAt is the schema descriptor for created_at field.
- paymentorderDescCreatedAt := paymentorderFields[35].Descriptor()
+ paymentorderDescCreatedAt := paymentorderFields[36].Descriptor()
// paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time)
// paymentorderDescUpdatedAt is the schema descriptor for updated_at field.
- paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor()
+ paymentorderDescUpdatedAt := paymentorderFields[37].Descriptor()
// paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time)
// paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go
index a9576d2a..64378de1 100644
--- a/backend/ent/schema/payment_order.go
+++ b/backend/ent/schema/payment_order.go
@@ -91,6 +91,10 @@ func (PaymentOrder) Fields() []ent.Field {
Optional().
Nillable().
MaxLen(64),
+ field.String("provider_key").
+ Optional().
+ Nillable().
+ MaxLen(30),
// 状态
field.String("status").
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 519455f0..83bac21d 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
instanceProviderKey = inst.ProviderKey
}
- expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey)
+ expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey)
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
"expectedProvider": expectedProviderKey,
@@ -69,10 +69,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
-func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string {
+func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
if key := strings.TrimSpace(instanceProviderKey); key != "" {
return key
}
+ if key := strings.TrimSpace(orderProviderKey); key != "" {
+ return key
+ }
if registry != nil {
if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" {
return key
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 4cc00301..712129b0 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -198,7 +198,7 @@ func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.
assert.Equal(t,
payment.TypeEasyPay,
- expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay),
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay),
)
}
@@ -213,7 +213,7 @@ func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *te
assert.Equal(t,
payment.TypeEasyPay,
- expectedNotificationProviderKey(registry, payment.TypeAlipay, ""),
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""),
)
}
@@ -222,6 +222,21 @@ func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) {
assert.Equal(t,
payment.TypeWxpay,
- expectedNotificationProviderKey(nil, payment.TypeWxpay, ""),
+ expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
)
}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index fa256be7..4b9b1872 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -251,7 +251,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
}
- _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
+ _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).
+ SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).
+ SetNillablePayURL(psNilIfEmpty(pr.PayURL)).
+ SetNillableQrCode(psNilIfEmpty(pr.QRCode)).
+ SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).
+ SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)).
+ Save(ctx)
if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err)
}
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 80147180..f804eb8b 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -241,7 +242,10 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
- providerKey := s.registry.GetProviderKey(o.PaymentType)
+ providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
+ if providerKey == "" {
+ providerKey = s.registry.GetProviderKey(o.PaymentType)
+ }
if providerKey == "" {
providerKey = o.PaymentType
}
@@ -255,3 +259,10 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
+
+func psStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
new file mode 100644
index 00000000..7ec19ae3
--- /dev/null
+++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
@@ -0,0 +1,10 @@
+ALTER TABLE payment_orders ADD COLUMN provider_key VARCHAR(30);
+
+UPDATE payment_orders
+SET provider_key = (
+ SELECT provider_key
+ FROM payment_provider_instances
+ WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id
+)
+WHERE provider_key IS NULL
+ AND provider_instance_id IS NOT NULL;
--
GitLab
From 9bebf1c1a611e37c68a127ffee07a670460a6a09 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:53:46 +0800
Subject: [PATCH 052/261] feat: resolve payment results by resume token
---
backend/internal/handler/payment_handler.go | 29 +++++
backend/internal/server/routes/payment.go | 1 +
.../internal/service/payment_resume_lookup.go | 35 ++++++
.../service/payment_resume_lookup_test.go | 118 ++++++++++++++++++
frontend/src/api/payment.ts | 5 +
frontend/src/views/user/PaymentResultView.vue | 12 +-
.../user/__tests__/PaymentResultView.spec.ts | 26 ++++
7 files changed, 225 insertions(+), 1 deletion(-)
create mode 100644 backend/internal/service/payment_resume_lookup.go
create mode 100644 backend/internal/service/payment_resume_lookup_test.go
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index d5c3d7c8..6440cdfd 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -357,6 +357,10 @@ type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"`
}
+type ResolveOrderByResumeTokenRequest struct {
+ ResumeToken string `json:"resume_token" binding:"required"`
+}
+
// VerifyOrder actively queries the upstream payment provider to check
// if payment was made, and processes it if so.
// POST /api/v1/payment/orders/verify
@@ -417,6 +421,31 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
})
}
+// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
+// POST /api/v1/payment/public/orders/resolve
+func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
+ var req ResolveOrderByResumeTokenRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, PublicOrderResult{
+ ID: order.ID,
+ OutTradeNo: order.OutTradeNo,
+ Amount: order.Amount,
+ PayAmount: order.PayAmount,
+ PaymentType: order.PaymentType,
+ OrderType: order.OrderType,
+ Status: order.Status,
+ })
+}
+
// requireAuth extracts the authenticated subject from the context.
// Returns the subject and true on success; on failure it writes an Unauthorized response and returns false.
func requireAuth(c *gin.Context) (middleware2.AuthSubject, bool) {
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index 23bd58ad..dff14a70 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -49,6 +49,7 @@ func RegisterPaymentRoutes(
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
+ public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken)
}
// --- Webhook endpoints (no auth) ---
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
new file mode 100644
index 00000000..493ca325
--- /dev/null
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -0,0 +1,35 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+)
+
+func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
+ claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token))
+ if err != nil {
+ return nil, err
+ }
+
+ order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
+ if err != nil {
+ return nil, fmt.Errorf("get order by resume token: %w", err)
+ }
+ if claims.UserID > 0 && order.UserID != claims.UserID {
+ return nil, fmt.Errorf("resume token user mismatch")
+ }
+ if claims.ProviderInstanceID != "" && strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != claims.ProviderInstanceID {
+ return nil, fmt.Errorf("resume token provider instance mismatch")
+ }
+ if claims.ProviderKey != "" && strings.TrimSpace(psStringValue(order.ProviderKey)) != claims.ProviderKey {
+ return nil, fmt.Errorf("resume token provider key mismatch")
+ }
+ if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
+ return nil, fmt.Errorf("resume token payment type mismatch")
+ }
+
+ return order, nil
+}
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
new file mode 100644
index 00000000..d411398e
--- /dev/null
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -0,0 +1,118 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instanceID := "12"
+ providerKey := payment.TypeEasyPay
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-ORDER").
+ SetOutTradeNo("sub2_resume_lookup").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-1").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instanceID).
+ SetProviderKey(providerKey).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: instanceID,
+ ProviderKey: providerKey,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
+func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-MISMATCH").
+ SetOutTradeNo("sub2_resume_lookup_mismatch").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-2").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("12").
+ SetProviderKey(payment.TypeEasyPay).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "99",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ _, err = svc.GetPublicOrderByResumeToken(ctx, token)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "resume token")
+}
diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts
index 5cedb107..91b16866 100644
--- a/frontend/src/api/payment.ts
+++ b/frontend/src/api/payment.ts
@@ -72,6 +72,11 @@ export const paymentAPI = {
return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo })
},
+ /** Resolve an order from a signed resume token without auth */
+ resolveOrderPublicByResumeToken(resumeToken: string) {
+ return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken })
+ },
+
/** Request a refund for a completed order */
requestRefund(id: number, data: { reason: string }) {
return apiClient.post(`/payment/orders/${id}/refund-request`, data)
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index e1db3ce2..9687d1c7 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -150,7 +150,17 @@ onMounted(async () => {
}
}
- if (orderId) {
+ if (!order.value && !orderId && resumeToken) {
+ try {
+ const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
+ order.value = result.data
+ orderId = result.data.id
+ } catch (_err: unknown) {
+ // Resume token recovery failed, continue to legacy fallback paths.
+ }
+ }
+
+ if (!order.value && orderId) {
try {
order.value = await paymentStore.pollOrderStatus(orderId)
} catch (_err: unknown) {
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index b06217ab..d23a60d9 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -9,6 +9,7 @@ const routerPush = vi.hoisted(() => vi.fn())
const pollOrderStatus = vi.hoisted(() => vi.fn())
const verifyOrderPublic = vi.hoisted(() => vi.fn())
const verifyOrder = vi.hoisted(() => vi.fn())
+const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn())
vi.mock('vue-router', async () => {
const actual = await vi.importActual('vue-router')
@@ -39,6 +40,7 @@ vi.mock('@/api/payment', () => ({
paymentAPI: {
verifyOrderPublic,
verifyOrder,
+ resolveOrderPublicByResumeToken,
},
}))
@@ -67,6 +69,7 @@ describe('PaymentResultView', () => {
pollOrderStatus.mockReset()
verifyOrderPublic.mockReset()
verifyOrder.mockReset()
+ resolveOrderPublicByResumeToken.mockReset()
window.localStorage.clear()
})
@@ -129,4 +132,27 @@ describe('PaymentResultView', () => {
expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123')
expect(wrapper.text()).toContain('payment.result.success')
})
+
+ it('resolves order by resume token when local recovery snapshot is missing', async () => {
+ routeState.query = {
+ resume_token: 'resume-77',
+ }
+ resolveOrderPublicByResumeToken.mockResolvedValue({
+ data: orderFactory('PAID'),
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-77')
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ })
})
--
GitLab
From 32059ae9d5fcf7024fadee26cf1e9c32dcc6da4b Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 20:58:19 +0800
Subject: [PATCH 053/261] fix: backfill email identities on successful login
---
.../internal/service/auth_oauth_email_flow.go | 6 +++
backend/internal/service/auth_service.go | 8 ++++
.../auth_service_identity_sync_test.go | 37 +++++++++++++++++++
3 files changed, 51 insertions(+)
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index ca3403d4..4ab4e245 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -147,5 +147,11 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa
// RecordSuccessfulLogin updates last-login activity after a non-standard login
// flow finishes with a real session.
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
+ if s != nil && s.userRepo != nil && userID > 0 {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err == nil {
+ s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
+ }
+ }
s.touchUserLogin(ctx, userID)
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 40753139..dda6df04 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -430,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
+ s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
s.touchUserLogin(ctx, user.ID)
// 生成JWT token
@@ -802,6 +803,13 @@ func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
}
}
+func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
+ if s == nil || user == nil || user.ID <= 0 {
+ return
+ }
+ s.ensureEmailAuthIdentity(ctx, user)
+}
+
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index 5bd2b25d..fcb4813b 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -150,4 +150,41 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.After(old))
require.True(t, storedUser.LastActiveAt.After(old))
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
+ svc, repo, client := newAuthServiceWithEnt(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "record@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 1,
+ Concurrency: 1,
+ }
+ require.NoError(t, user.SetPassword("password"))
+ require.NoError(t, repo.Create(ctx, user))
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("record@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
}
--
GitLab
From bdcd3d87e530fc1bbbec4888a51ae79c1cb0e298 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:09:38 +0800
Subject: [PATCH 054/261] fix: resolve unique legacy payment providers
---
.../service/payment_order_lifecycle.go | 46 +++++---
backend/internal/service/payment_refund.go | 79 ++++++++++++-
.../service/payment_webhook_provider.go | 22 ++--
.../service/payment_webhook_provider_test.go | 109 ++++++++++++++++++
4 files changed, 220 insertions(+), 36 deletions(-)
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index f804eb8b..24d8b7a2 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"strconv"
- "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -237,29 +236,38 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error)
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
- instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
- if err == nil {
- cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
- if err == nil {
- providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
- if providerKey == "" {
- providerKey = s.registry.GetProviderKey(o.PaymentType)
- }
- if providerKey == "" {
- providerKey = o.PaymentType
- }
- p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
- if err == nil {
- return p, nil
- }
- }
- }
+ inst, err := s.getOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst != nil {
+ return s.createProviderFromInstance(ctx, inst)
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
+func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
+ if inst == nil {
+ return nil, fmt.Errorf("payment provider instance is missing")
+ }
+
+ cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance config: %w", err)
+ }
+ if inst.PaymentMode != "" {
+ cfg["paymentMode"] = inst.PaymentMode
+ }
+
+ instID := strconv.FormatInt(int64(inst.ID), 10)
+ prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
+ if err != nil {
+ return nil, fmt.Errorf("create provider from instance: %w", err)
+ }
+ return prov, nil
+}
+
func psStringValue(value *string) string {
if value == nil {
return ""
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index c5bda763..01eecff8 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
@@ -19,18 +20,90 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
-// Returns nil, nil for legacy orders without provider_instance_id.
+// For legacy orders without provider_instance_id, it resolves only when the
+// enabled instance is uniquely identifiable from the stored order fields.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
- if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
+ if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
- instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
+func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
+ if providerKey != "" {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.EnabledEQ(true),
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if len(instances) == 1 {
+ return instances[0], nil
+ }
+ return nil, nil
+ }
+
+ paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
+ if paymentType == "" {
+ return nil, nil
+ }
+
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.EnabledEQ(true)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var matched []*dbent.PaymentProviderInstance
+ for _, inst := range instances {
+ if psLegacyOrderMatchesInstance(paymentType, inst) {
+ matched = append(matched, inst)
+ }
+ }
+ if len(matched) == 1 {
+ return matched[0], nil
+ }
+ return nil, nil
+}
+
+func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool {
+ if inst == nil {
+ return false
+ }
+
+ baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType))
+ instanceProviderKey := strings.TrimSpace(inst.ProviderKey)
+ if baseType == "" {
+ return false
+ }
+
+ if baseType == payment.TypeStripe {
+ return instanceProviderKey == payment.TypeStripe
+ }
+ if instanceProviderKey == payment.TypeStripe {
+ return false
+ }
+ if instanceProviderKey == baseType {
+ return true
+ }
+ return payment.InstanceSupportsType(inst.SupportedTypes, baseType)
+}
+
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
index a877db2b..289d63ed 100644
--- a/backend/internal/service/payment_webhook_provider.go
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -4,14 +4,12 @@ import (
"context"
"fmt"
"log/slog"
- "strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
- "github.com/Wei-Shaw/sub2api/internal/payment/provider"
)
// GetWebhookProvider returns the provider instance that should verify a webhook.
@@ -24,6 +22,13 @@ func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, ou
if psHasPinnedProviderInstance(order) {
return s.getPinnedOrderProvider(ctx, order)
}
+ inst, err := s.getOrderProviderInstance(ctx, order)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst != nil {
+ return s.createProviderFromInstance(ctx, inst)
+ }
if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
}
@@ -48,18 +53,7 @@ func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.Pa
if inst == nil {
return nil, fmt.Errorf("order %d provider instance is missing", o.ID)
}
-
- instID := strconv.FormatInt(int64(inst.ID), 10)
- cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
- if err != nil {
- return nil, fmt.Errorf("load provider instance config: %w", err)
- }
-
- prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
- if err != nil {
- return nil, fmt.Errorf("create pinned provider: %w", err)
- }
- return prov, nil
+ return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool {
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
index 85c296de..33e4186d 100644
--- a/backend/internal/service/payment_webhook_provider_test.go
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -4,13 +4,17 @@ package service
import (
"context"
+ "encoding/json"
"testing"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/require"
)
+const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef"
+
type webhookProviderTestDouble struct {
key string
types []payment.PaymentType
@@ -32,6 +36,111 @@ func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest
panic("unexpected call")
}
+func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string {
+ t.Helper()
+
+ data, err := json.Marshal(config)
+ require.NoError(t, err)
+
+ encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey))
+ require.NoError(t, err)
+ return encrypted
+}
+
+func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer {
+ return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey))
+}
+
+func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-a").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeStripe
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeStripe,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpayDirect,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("easypay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
--
GitLab
From 5adefb466b9c7a1cd02e2c9f1c0221c08d38e5bf Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:24:33 +0800
Subject: [PATCH 055/261] fix: finalize oauth identity bindings
---
.../handler/auth_oauth_pending_flow.go | 152 +++++++++++++++++-
.../handler/auth_oauth_pending_flow_test.go | 11 +-
backend/internal/handler/auth_wechat_oauth.go | 96 +++++++++++
.../handler/auth_wechat_oauth_test.go | 124 ++++++++++++++
4 files changed, 376 insertions(+), 7 deletions(-)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 6d6564e8..21ed2bc6 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -10,6 +10,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any {
return cloned
}
+func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
+ merged := cloneOAuthMetadata(base)
+ for key, value := range overlay {
+ merged[key] = value
+ }
+ return merged
+}
+
func normalizeAdoptedOAuthDisplayName(value string) string {
value = strings.TrimSpace(value)
if len([]rune(value)) > 100 {
@@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
}
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
+ return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
+ }
+
client := tx.Client()
identity, err := client.AuthIdentity.Query().
Where(
@@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio
return create.Save(ctx)
}
+func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ client := tx.Client()
+ providerType := strings.TrimSpace(session.ProviderType)
+ providerKey := strings.TrimSpace(session.ProviderKey)
+ providerSubject := strings.TrimSpace(session.ProviderSubject)
+ channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
+ channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
+ channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
+ metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyEQ(providerKey),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if identity != nil && identity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+
+ var legacyOpenIDIdentity *dbent.AuthIdentity
+ if channelSubject != "" && channelSubject != providerSubject {
+ legacyOpenIDIdentity, err = client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyEQ(providerKey),
+ authidentity.ProviderSubjectEQ(channelSubject),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+
+ switch {
+ case identity != nil:
+ update := client.AuthIdentity.UpdateOneID(identity.ID).
+ SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ case legacyOpenIDIdentity != nil:
+ update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
+ SetProviderSubject(providerSubject).
+ SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if channel == "" || channelAppID == "" || channelSubject == "" {
+ return identity, nil
+ }
+
+ channelRecord, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyEQ(providerKey),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelSubject),
+ ).
+ WithIdentity().
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+
+ channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
+ if channelRecord == nil {
+ if _, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetChannel(channel).
+ SetChannelAppID(channelAppID).
+ SetChannelSubject(channelSubject).
+ SetMetadata(channelMetadata).
+ Save(ctx); err != nil {
+ return nil, err
+ }
+ return identity, nil
+ }
+
+ _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
+ SetIdentityID(identity.ID).
+ SetMetadata(channelMetadata).
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return identity, nil
+}
+
+func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
+ if channel == nil {
+ return map[string]any{}
+ }
+ return cloneOAuthMetadata(channel.Metadata)
+}
+
func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
if session == nil || decision == nil {
return false
}
- if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") {
+ switch strings.ToLower(strings.TrimSpace(session.Intent)) {
+ case "bind_current_user", "login", "adopt_existing_user_by_email":
return true
+ default:
+ return decision.AdoptDisplayName || decision.AdoptAvatar
}
- return decision.AdoptDisplayName || decision.AdoptAvatar
}
func applyPendingOAuthBinding(
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index ae506e52..89accd60 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi
require.Nil(t, storedSession.ConsumedAt)
}
-func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) {
+func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes
require.Equal(t, http.StatusOK, recorder.Code)
- identityCount, err := client.AuthIdentity.Query().
+ identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("login-false-123"),
).
- Count(ctx)
+ Only(ctx)
require.NoError(t, err)
- require.Zero(t, identityCount)
+ require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
- require.Nil(t, decision.IdentityID)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index f0755f1f..b6d47670 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
+ if existingIdentityUser == nil {
+ existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ }
if existingIdentityUser != nil {
+ if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
@@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership(
return nil
}
+func (h *AuthHandler) findWeChatUserByLegacyOpenID(
+ ctx context.Context,
+ identity service.PendingAuthIdentityKey,
+ cfg wechatOAuthConfig,
+ openid string,
+) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ openid = strings.TrimSpace(openid)
+ channel := strings.TrimSpace(cfg.mode)
+ channelAppID := strings.TrimSpace(cfg.appID)
+ if openid != "" && channel != "" && channelAppID != "" {
+ record, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(openid),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil {
+ return record.Edges.Identity.Edges.User, nil
+ }
+ }
+
+ if openid == "" {
+ return nil, nil
+ }
+
+ record, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderSubjectEQ(openid),
+ ).
+ WithUser().
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ return record.Edges.User, nil
+}
+
+func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding(
+ ctx context.Context,
+ userID int64,
+ identity service.PendingAuthIdentityKey,
+ upstreamClaims map[string]any,
+) error {
+ client := h.entClient()
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims),
+ }, userID)
+ if err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
mode, err := resolveWeChatOAuthMode(rawMode, c)
if err != nil {
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index 1ff80e1b..dd022fb9 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -15,6 +15,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
@@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, identity.ID, channel.IdentityID)
+ require.Equal(t, "union-456", channel.Metadata["unionid"])
+
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Only(ctx)
@@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.NotNil(t, consumed.ConsumedAt)
}
+// TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity verifies that a login
+// callback repairs a pre-unionid identity in place: an identity whose provider
+// subject is still the legacy openid is re-keyed to the unionid, no
+// openid-keyed identity row survives, and an AuthIdentityChannel record is
+// created linking the openid back to the repaired identity.
+func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
+	t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+	t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+	t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+	// Save and restore the package-level upstream URLs patched below.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	originalUserInfoURL := wechatOAuthUserInfoURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+		wechatOAuthUserInfoURL = originalUserInfoURL
+	})
+
+	// Fake WeChat upstream: the token response carries both openid and unionid.
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+		case strings.Contains(r.URL.Path, "/sns/userinfo"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+	handler, client := newWeChatOAuthTestHandler(t, false)
+	defer client.Close()
+
+	ctx := context.Background()
+	legacyUser, err := client.User.Create().
+		SetEmail("legacy@example.com").
+		SetUsername("legacy-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Seed the legacy state: identity keyed by the openid rather than unionid.
+	legacyIdentity, err := client.AuthIdentity.Create().
+		SetUserID(legacyUser.ID).
+		SetProviderType("wechat").
+		SetProviderKey(wechatOAuthProviderKey).
+		SetProviderSubject("openid-123").
+		SetMetadata(map[string]any{"openid": "openid-123"}).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+	req.Host = "api.example.com"
+	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+	c.Request = req
+
+	handler.WeChatOAuthCallback(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+	require.NotNil(t, sessionCookie)
+
+	// The pending session must target the legacy user for login continuation.
+	session, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, session.TargetUserID)
+	require.Equal(t, legacyUser.ID, *session.TargetUserID)
+	require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+	// The same identity row (same ID) should now be keyed by the unionid.
+	repairedIdentity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("wechat"),
+			authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+			authidentity.ProviderSubjectEQ("union-456"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+	require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+	// No openid-keyed identity row should survive the repair.
+	openIDIdentityCount, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("wechat"),
+			authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+			authidentity.ProviderSubjectEQ("openid-123"),
+		).
+		Count(ctx)
+	require.NoError(t, err)
+	require.Zero(t, openIDIdentityCount)
+
+	// The openid is instead tracked as a channel of the repaired identity.
+	channel, err := client.AuthIdentityChannel.Query().
+		Where(
+			authidentitychannel.ProviderTypeEQ("wechat"),
+			authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+			authidentitychannel.ChannelEQ("open"),
+			authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+			authidentitychannel.ChannelSubjectEQ("openid-123"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
--
GitLab
From f65429145e57c7bb61844fc978a3ef92e53710ef Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:31:05 +0800
Subject: [PATCH 056/261] fix: route legacy linuxdo users to account binding
---
.../internal/handler/auth_linuxdo_oauth.go | 59 ++++++++++++++
.../handler/auth_linuxdo_oauth_test.go | 76 +++++++++++++++++++
2 files changed, 135 insertions(+)
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index c3ec3804..a0760a3b 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -15,6 +15,8 @@ import (
"time"
"unicode/utf8"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -237,6 +239,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
+ compatEmail := strings.TrimSpace(email)
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
@@ -255,6 +258,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
"suggested_display_name": displayName,
"suggested_avatar_url": avatarURL,
}
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
if err != nil {
@@ -314,6 +320,33 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
+ compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if compatEmailUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "adopt_existing_user_by_email",
+ Identity: identityKey,
+ TargetUserID: &compatEmailUser.ID,
+ ResolvedEmail: compatEmailUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ "step": "bind_login_required",
+ "email": compatEmailUser.Email,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@@ -372,6 +405,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
redirectToFrontendCallback(c, frontendCallback)
}
+// findLinuxDoCompatEmailUser looks up a local user whose stored email equals
+// the email reported by LinuxDo Connect (case-insensitive match via
+// EmailEqualFold), so legacy accounts created with a real email can be routed
+// to an explicit binding flow instead of a fresh signup. It returns
+// (nil, nil) when the email is blank, when it ends with one of our synthetic
+// per-provider email domains, or when no user matches.
+func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+	client := h.entClient()
+	if client == nil {
+		return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+
+	email = strings.TrimSpace(strings.ToLower(email))
+	// Synthetic emails are derived from provider subjects, never typed by a
+	// human, so they cannot identify a pre-existing local account; matching
+	// them here could bind the wrong user.
+	if email == "" ||
+		strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+		strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+		strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+		return nil, nil
+	}
+
+	userEntity, err := client.User.Query().
+		Where(dbuser.EmailEqualFold(email)).
+		Only(ctx)
+	if err != nil {
+		// Not-found is the normal "no compat user" outcome, not a failure.
+		if dbent.IsNotFound(err) {
+			return nil, nil
+		}
+		return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+	}
+	return userEntity, nil
+}
+
type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index 765779b5..fb57e570 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -300,6 +300,82 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testin
require.Nil(t, completion["error"])
}
+// TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser verifies
+// that when the LinuxDo upstream reports an email matching an existing local
+// account, the callback does not log the user in directly but creates an
+// "adopt_existing_user_by_email" pending session (step "bind_login_required")
+// with no access token, forcing local authentication before binding.
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+	// Fake LinuxDo upstream returning an email that collides with a local user.
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.URL.Path {
+		case "/token":
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+		case "/userinfo":
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+
+	handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+		Enabled:             true,
+		ClientID:            "linuxdo-client",
+		ClientSecret:        "linuxdo-secret",
+		AuthorizeURL:        upstream.URL + "/authorize",
+		TokenURL:            upstream.URL + "/token",
+		UserInfoURL:         upstream.URL + "/userinfo",
+		Scopes:              "read",
+		RedirectURL:         "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+		FrontendRedirectURL: "/auth/linuxdo/callback",
+		TokenAuthMethod:     "client_secret_post",
+		UsePKCE:             true,
+	})
+	defer client.Close()
+
+	// Pre-existing local account whose email matches the upstream claim.
+	ctx := context.Background()
+	existingUser, err := client.User.Create().
+		SetEmail("legacy@example.com").
+		SetUsername("legacy-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil)
+	req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat"))
+	req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+	req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat"))
+	req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+	c.Request = req
+
+	handler.LinuxDoOAuthCallback(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+	require.NotNil(t, sessionCookie)
+
+	// The pending session must target the existing user and record the
+	// upstream email as a compat claim.
+	session, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, "adopt_existing_user_by_email", session.Intent)
+	require.NotNil(t, session.TargetUserID)
+	require.Equal(t, existingUser.ID, *session.TargetUserID)
+	require.Equal(t, existingUser.Email, session.ResolvedEmail)
+	require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+	// Completion payload directs the frontend to a bind-login step and must
+	// not carry an access token (no silent account takeover).
+	completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+	require.Equal(t, "/dashboard", completion["redirect"])
+	require.Equal(t, "bind_login_required", completion["step"])
+	require.Equal(t, existingUser.Email, completion["email"])
+	_, hasAccessToken := completion["access_token"]
+	require.False(t, hasAccessToken)
+}
+
func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
--
GitLab
From 422f60a14513b8c9df33bdf8209d82fb6888d7b9 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:42:35 +0800
Subject: [PATCH 057/261] fix: normalize legacy wechat auth identity keys
---
.../handler/auth_oauth_pending_flow.go | 109 ++++++++--
backend/internal/handler/auth_wechat_oauth.go | 122 ++++++++---
.../handler/auth_wechat_oauth_test.go | 192 ++++++++++++++++++
...3_normalize_legacy_wechat_provider_key.sql | 89 ++++++++
4 files changed, 464 insertions(+), 48 deletions(-)
create mode 100644 backend/migrations/113_normalize_legacy_wechat_provider_key.sql
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 21ed2bc6..fd35e4e5 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -606,39 +606,42 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
providerType := strings.TrimSpace(session.ProviderType)
providerKey := strings.TrimSpace(session.ProviderKey)
providerSubject := strings.TrimSpace(session.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(providerKey)
channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
- identity, err := client.AuthIdentity.Query().
+ identityRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
- authidentity.ProviderKeyEQ(providerKey),
+ authidentity.ProviderKeyIn(providerKeys...),
authidentity.ProviderSubjectEQ(providerSubject),
).
- Only(ctx)
- if err != nil && !dbent.IsNotFound(err) {
+ All(ctx)
+ if err != nil {
return nil, err
}
- if identity != nil && identity.UserID != userID {
- return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(identityRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
}
var legacyOpenIDIdentity *dbent.AuthIdentity
if channelSubject != "" && channelSubject != providerSubject {
- legacyOpenIDIdentity, err = client.AuthIdentity.Query().
+ legacyOpenIDRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
- authidentity.ProviderKeyEQ(providerKey),
+ authidentity.ProviderKeyIn(providerKeys...),
authidentity.ProviderSubjectEQ(channelSubject),
).
- Only(ctx)
- if err != nil && !dbent.IsNotFound(err) {
+ All(ctx)
+ if err != nil {
return nil, err
}
- if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID {
- return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(legacyOpenIDRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
}
}
@@ -646,6 +649,9 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
case identity != nil:
update := client.AuthIdentity.UpdateOneID(identity.ID).
SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
+ if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey {
+ update = update.SetProviderKey(providerKey)
+ }
if issuer := oauthIdentityIssuer(session); issuer != nil {
update = update.SetIssuer(strings.TrimSpace(*issuer))
}
@@ -655,6 +661,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
}
case legacyOpenIDIdentity != nil:
update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
+ SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
if issuer := oauthIdentityIssuer(session); issuer != nil {
@@ -684,21 +691,22 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
return identity, nil
}
- channelRecord, err := client.AuthIdentityChannel.Query().
+ channelRecords, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
- authidentitychannel.ProviderKeyEQ(providerKey),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
authidentitychannel.ChannelEQ(channel),
authidentitychannel.ChannelAppIDEQ(channelAppID),
authidentitychannel.ChannelSubjectEQ(channelSubject),
).
WithIdentity().
- Only(ctx)
- if err != nil && !dbent.IsNotFound(err) {
+ All(ctx)
+ if err != nil {
return nil, err
}
- if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID {
- return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(channelRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
}
channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
@@ -717,16 +725,75 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
return identity, nil
}
- _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
+ updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
SetIdentityID(identity.ID).
- SetMetadata(channelMetadata).
- Save(ctx)
+ SetMetadata(channelMetadata)
+ if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey {
+ updateChannel = updateChannel.SetProviderKey(providerKey)
+ }
+ _, err = updateChannel.Save(ctx)
if err != nil {
return nil, err
}
return identity, nil
}
+// chooseWeChatIdentityForUser selects a single identity record owned by
+// userID from a result set that may mix canonical and legacy provider keys.
+// It returns the chosen record (preferring one with the canonical key),
+// whether any record already carries the preferred key, and an
+// ownership-conflict error when any record belongs to a different user.
+func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
+	var canonical, legacy *dbent.AuthIdentity
+	canonicalSeen := false
+	for _, rec := range records {
+		if rec == nil {
+			continue
+		}
+		// Any record owned by another user is a hard conflict.
+		if rec.UserID != userID {
+			return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+		}
+		switch {
+		case strings.EqualFold(strings.TrimSpace(rec.ProviderKey), preferredProviderKey):
+			canonicalSeen = true
+			if canonical == nil {
+				canonical = rec
+			}
+		case legacy == nil:
+			legacy = rec
+		}
+	}
+	if canonical != nil {
+		return canonical, canonicalSeen, nil
+	}
+	return legacy, canonicalSeen, nil
+}
+
+// chooseWeChatChannelForUser selects a single channel record for userID from
+// a result set that may mix canonical and legacy provider keys. It returns
+// the chosen record (preferring one with the canonical key), whether any
+// record already carries the preferred key, and an ownership-conflict error
+// when a record's loaded identity belongs to a different user.
+func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
+	var canonical, legacy *dbent.AuthIdentityChannel
+	canonicalSeen := false
+	for _, rec := range records {
+		if rec == nil {
+			continue
+		}
+		// Only a loaded identity edge owned by someone else is a conflict.
+		if rec.Edges.Identity != nil && rec.Edges.Identity.UserID != userID {
+			return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+		}
+		switch {
+		case strings.EqualFold(strings.TrimSpace(rec.ProviderKey), preferredProviderKey):
+			canonicalSeen = true
+			if canonical == nil {
+				canonical = rec
+			}
+		case legacy == nil:
+			legacy = rec
+		}
+	}
+	if canonical != nil {
+		return canonical, canonicalSeen, nil
+	}
+	return legacy, canonicalSeen, nil
+}
+
func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
if channel == nil {
return map[string]any{}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index b6d47670..95993dfc 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -34,6 +34,7 @@ const (
wechatOAuthDefaultRedirectTo = "/dashboard"
wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
wechatOAuthProviderKey = "wechat-main"
+ wechatOAuthLegacyProviderKey = "wechat"
wechatOAuthIntentLogin = "login"
wechatOAuthIntentBind = "bind_current_user"
@@ -483,18 +484,20 @@ func (h *AuthHandler) ensureWeChatBindOwnership(
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
- identity, err := client.AuthIdentity.Query().
+ identities, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
- authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)),
).
- Only(ctx)
- if err != nil && !dbent.IsNotFound(err) {
+ All(ctx)
+ if err != nil {
return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err)
}
- if identity != nil && identity.UserID != userID {
- return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ for _, identity := range identities {
+ if identity != nil && identity.UserID != userID {
+ return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
}
channelSubject = strings.TrimSpace(channelSubject)
@@ -503,21 +506,23 @@ func (h *AuthHandler) ensureWeChatBindOwnership(
return nil
}
- channel, err := client.AuthIdentityChannel.Query().
+ channels, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
- authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)),
authidentitychannel.ChannelAppIDEQ(channelAppID),
authidentitychannel.ChannelSubjectEQ(channelSubject),
).
WithIdentity().
- Only(ctx)
- if err != nil && !dbent.IsNotFound(err) {
+ All(ctx)
+ if err != nil {
return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err)
}
- if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
- return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ for _, channel := range channels {
+ if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
+ return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
}
return nil
}
@@ -533,14 +538,34 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
+ providerType := strings.TrimSpace(identity.ProviderType)
+ providerSubject := strings.TrimSpace(identity.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey)
+ if providerSubject != "" {
+ records, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ WithUser().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
+ return user, err
+ }
+ }
+
openid = strings.TrimSpace(openid)
channel := strings.TrimSpace(cfg.mode)
channelAppID := strings.TrimSpace(cfg.appID)
if openid != "" && channel != "" && channelAppID != "" {
- record, err := client.AuthIdentityChannel.Query().
+ records, err := client.AuthIdentityChannel.Query().
Where(
- authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
- authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
authidentitychannel.ChannelEQ(channel),
authidentitychannel.ChannelAppIDEQ(channelAppID),
authidentitychannel.ChannelSubjectEQ(openid),
@@ -548,12 +573,12 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
WithIdentity(func(q *dbent.AuthIdentityQuery) {
q.WithUser()
}).
- Only(ctx)
- if err != nil && !dbent.IsNotFound(err) {
+ All(ctx)
+ if err != nil {
return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
- if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil {
- return record.Edges.Identity.Edges.User, nil
+ if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
+ return user, err
}
}
@@ -561,21 +586,64 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return nil, nil
}
- record, err := client.AuthIdentity.Query().
+ records, err := client.AuthIdentity.Query().
Where(
- authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
- authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
authidentity.ProviderSubjectEQ(openid),
).
WithUser().
- Only(ctx)
+ All(ctx)
if err != nil {
- if dbent.IsNotFound(err) {
- return nil, nil
- }
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
- return record.Edges.User, nil
+ return singleWeChatIdentityUser(records)
+}
+
+// wechatCompatibleProviderKeys returns the provider keys to match when
+// querying WeChat identities: the preferred key (falling back to the
+// canonical wechatOAuthProviderKey when blank) plus the legacy "wechat" key
+// whenever it differs, so pre-migration rows are still found.
+func wechatCompatibleProviderKeys(providerKey string) []string {
+	key := strings.TrimSpace(providerKey)
+	if key == "" {
+		key = wechatOAuthProviderKey
+	}
+	// Avoid a duplicate entry when the preferred key already is the legacy key.
+	if strings.EqualFold(key, wechatOAuthLegacyProviderKey) {
+		return []string{key}
+	}
+	return []string{key, wechatOAuthLegacyProviderKey}
+}
+
+// singleWeChatIdentityUser resolves the unique owning user across identity
+// records that may span canonical and legacy provider keys. Records without
+// a loaded user edge are skipped; nil is returned when none resolve; an
+// ownership-conflict error is returned when two records disagree on the user.
+func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) {
+	var owner *dbent.User
+	for _, rec := range records {
+		if rec == nil || rec.Edges.User == nil {
+			continue
+		}
+		switch {
+		case owner == nil:
+			owner = rec.Edges.User
+		case owner.ID != rec.Edges.User.ID:
+			return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+		}
+	}
+	return owner, nil
+}
+
+// singleWeChatChannelUser resolves the unique owning user across channel
+// records that may span canonical and legacy provider keys. Records without
+// a fully loaded identity->user edge chain are skipped; nil is returned when
+// none resolve; an ownership-conflict error is returned when two records
+// disagree on the user.
+func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) {
+	var owner *dbent.User
+	for _, rec := range records {
+		if rec == nil || rec.Edges.Identity == nil || rec.Edges.Identity.Edges.User == nil {
+			continue
+		}
+		switch {
+		case owner == nil:
+			owner = rec.Edges.Identity.Edges.User
+		case owner.ID != rec.Edges.Identity.Edges.User.ID:
+			return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+		}
+	}
+	return owner, nil
}
func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding(
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index dd022fb9..def9d5d6 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -467,6 +467,88 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
require.Zero(t, count)
}
+// TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict
+// verifies that a bind attempt is rejected with AUTH_IDENTITY_OWNERSHIP_CONFLICT
+// when the unionid is already claimed by another user under the LEGACY
+// provider key ("wechat"), proving the ownership check also scans
+// legacy-keyed rows and that no pending session is created on conflict.
+func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
+	t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+	t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+	t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+	// Save and restore the package-level upstream URLs patched below.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	originalUserInfoURL := wechatOAuthUserInfoURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+		wechatOAuthUserInfoURL = originalUserInfoURL
+	})
+
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+		case strings.Contains(r.URL.Path, "/sns/userinfo"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+	handler, client := newWeChatOAuthTestHandler(t, false)
+	defer client.Close()
+
+	// "owner" already holds the unionid; "currentUser" attempts to bind it.
+	ctx := context.Background()
+	owner, err := client.User.Create().
+		SetEmail("owner@example.com").
+		SetUsername("owner").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	currentUser, err := client.User.Create().
+		SetEmail("current@example.com").
+		SetUsername("current").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// The conflicting identity deliberately uses the legacy provider key.
+	_, err = client.AuthIdentity.Create().
+		SetUserID(owner.ID).
+		SetProviderType("wechat").
+		SetProviderKey(wechatOAuthLegacyProviderKey).
+		SetProviderSubject("union-456").
+		SetMetadata(map[string]any{"unionid": "union-456"}).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+	req.Host = "api.example.com"
+	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+	req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+	req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+	c.Request = req
+
+	handler.WeChatOAuthCallback(c)
+
+	// Conflict redirects with an error and issues no pending-session cookie.
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+	assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+	count, err := client.PendingAuthSession.Query().Count(ctx)
+	require.NoError(t, err)
+	require.Zero(t, count)
+}
+
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
@@ -703,6 +785,116 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
}
+func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ legacyIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, legacyIdentityCount)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
t.Helper()
diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
new file mode 100644
index 00000000..15610af0
--- /dev/null
+++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
@@ -0,0 +1,89 @@
+UPDATE auth_identities AS ai
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+ );
+
+UPDATE auth_identity_channels AS channel
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identity_channels AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+ );
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_provider_key_conflict',
+ CAST(ai.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_identity_id', ai.id,
+ 'legacy_user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'canonical_identity_id', canon.id,
+ 'canonical_user_id', canon.user_id,
+ 'same_user', canon.user_id = ai.user_id,
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identities AS ai
+JOIN auth_identities AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_channel_provider_key_conflict',
+ CAST(channel.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_channel_id', channel.id,
+ 'legacy_identity_id', channel.identity_id,
+ 'canonical_channel_id', canon.id,
+ 'canonical_identity_id', canon.identity_id,
+ 'channel', channel.channel,
+ 'channel_app_id', channel.channel_app_id,
+ 'channel_subject', channel.channel_subject,
+ 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE),
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identity_channels AS channel
+JOIN auth_identity_channels AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+LEFT JOIN auth_identities AS legacy_identity
+ ON legacy_identity.id = channel.identity_id
+LEFT JOIN auth_identities AS canonical_identity
+ ON canonical_identity.id = canon.identity_id
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
--
GitLab
From b30982219962c96d4e7c33d6aaeff4171eccbe88 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:46:24 +0800
Subject: [PATCH 058/261] fix: tighten legacy payment provider resolution
---
backend/internal/service/payment_refund.go | 36 +++++++----
.../service/payment_webhook_provider_test.go | 64 +++++++++++++++++++
2 files changed, 86 insertions(+), 14 deletions(-)
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index 01eecff8..57469fa3 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -21,7 +21,7 @@ import (
// getOrderProviderInstance looks up the provider instance that processed this order.
// For legacy orders without provider_instance_id, it resolves only when the
-// enabled instance is uniquely identifiable from the stored order fields.
+// historical instance is uniquely identifiable from the stored order fields.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if s == nil || s.entClient == nil || o == nil {
return nil, nil
@@ -40,45 +40,53 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
}
func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
if providerKey != "" {
instances, err := s.entClient.PaymentProviderInstance.Query().
- Where(
- paymentproviderinstance.EnabledEQ(true),
- paymentproviderinstance.ProviderKeyEQ(providerKey),
- ).
+ Where(paymentproviderinstance.ProviderKeyEQ(providerKey)).
All(ctx)
if err != nil {
return nil, err
}
- if len(instances) == 1 {
- return instances[0], nil
+ matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
+ if len(matched) == 1 {
+ return matched[0], nil
}
return nil, nil
}
- paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
if paymentType == "" {
return nil, nil
}
instances, err := s.entClient.PaymentProviderInstance.Query().
- Where(paymentproviderinstance.EnabledEQ(true)).
All(ctx)
if err != nil {
return nil, err
}
+ matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
+ if len(matched) == 1 {
+ return matched[0], nil
+ }
+ return nil, nil
+}
+
+func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance {
+ if len(instances) == 0 {
+ return nil
+ }
+ if strings.TrimSpace(orderPaymentType) == "" {
+ return instances
+ }
var matched []*dbent.PaymentProviderInstance
for _, inst := range instances {
- if psLegacyOrderMatchesInstance(paymentType, inst) {
+ if psLegacyOrderMatchesInstance(orderPaymentType, inst) {
matched = append(matched, inst)
}
}
- if len(matched) == 1 {
- return matched[0], nil
- }
- return nil, nil
+ return matched
}
func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool {
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
index 33e4186d..4f0b6848 100644
--- a/backend/internal/service/payment_webhook_provider_test.go
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -141,6 +141,70 @@ func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing
require.Nil(t, got)
}
+func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-disabled-legacy").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(false).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-enabled-current").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeStripe
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeStripe,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-only").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeWxpay
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipayDirect,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
--
GitLab
From 31d0183d45b3c5d2a6692daebec58625a393fe38 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:51:57 +0800
Subject: [PATCH 059/261] fix: normalize repository email lookups
---
backend/internal/repository/user_repo.go | 30 +++++++-
.../user_repo_email_lookup_unit_test.go | 69 +++++++++++++++++++
2 files changed, 96 insertions(+), 3 deletions(-)
create mode 100644 backend/internal/repository/user_repo_email_lookup_unit_test.go
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 0c607ecc..b5efd19d 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
@@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
- m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ matches, err := r.client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
if err != nil {
- return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
+ return nil, err
+ }
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
}
+ if len(matches) > 1 {
+ return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
+ }
+ m := matches[0]
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
@@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
- return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
+ return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
+}
+
+func userEmailLookupPredicate(email string) predicate.User {
+ normalized := strings.TrimSpace(email)
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.ExprP(
+ fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
+ normalized,
+ ))
+ })
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go
new file mode 100644
index 00000000..d42ce9ac
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go
@@ -0,0 +1,69 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return newUserRepositoryWithSQL(client, db), client
+}
+
+func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ got, err := repo.GetByEmail(ctx, "legacy@example.com")
+ require.NoError(t, err)
+ require.Equal(t, " Legacy@Example.com ", got.Email)
+}
+
+func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ")
+ require.NoError(t, err)
+ require.True(t, exists)
+}
--
GitLab
From aaf4946b27ebc03163d0832b7946234f44b6b769 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 21:59:03 +0800
Subject: [PATCH 060/261] fix: normalize pending oauth email lookups
---
.../handler/auth_oauth_pending_flow.go | 47 ++++++++--
.../handler/auth_oauth_pending_flow_test.go | 85 +++++++++++++++++++
2 files changed, 126 insertions(+), 6 deletions(-)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index fd35e4e5..1b3c1380 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -3,6 +3,7 @@ package handler
import (
"context"
"errors"
+ "fmt"
"io"
"net/http"
"net/url"
@@ -12,12 +13,14 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
+ entsql "entgo.io/ent/dialect/sql"
"github.com/gin-gonic/gin"
)
@@ -531,11 +534,9 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client,
return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
}
- userEntity, err := client.User.Query().
- Where(dbuser.EmailEQ(email)).
- Only(ctx)
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
if err != nil {
- if dbent.IsNotFound(err) {
+ if errors.Is(err, service.ErrUserNotFound) {
return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
}
return 0, err
@@ -543,6 +544,40 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client,
return userEntity.ID, nil
}
+func userNormalizedEmailPredicate(email string) predicate.User {
+ normalized := strings.TrimSpace(email)
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.ExprP(
+ fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
+ normalized,
+ ))
+ })
+}
+
+func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) {
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ matches, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+ return matches[0], nil
+}
+
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil {
return nil
@@ -1102,8 +1137,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
email := strings.TrimSpace(strings.ToLower(req.Email))
- existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context())
- if err != nil && !dbent.IsNotFound(err) {
+ existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
+ if err != nil && !errors.Is(err, service.ErrUserNotFound) {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
return
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 89accd60..8c468fdc 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -642,6 +642,60 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState
require.Zero(t, identityCount)
}
+func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-normalized-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-normalized-123").
+ SetBrowserSessionKey("existing-email-normalized-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "adopt_existing_user_by_email", payload["intent"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -884,6 +938,37 @@ func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
}
+func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ _ = handler
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("resolve-target-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-target-123").
+ SetResolvedEmail("owner@example.com").
+ SetBrowserSessionKey("resolve-target-browser-session-key").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, resolvedUserID)
+}
+
func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) {
totpCache := &oauthPendingFlowTotpCacheStub{}
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
--
GitLab
From bbc4aed3d91c8b3312c320843e3d4bea54e17d80 Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 20 Apr 2026 22:01:09 +0800
Subject: [PATCH 061/261] =?UTF-8?q?fix(openai):=20=E7=A7=BB=E9=99=A4?=
=?UTF-8?q?=E5=B7=B2=E4=B8=8B=E7=BA=BF=20Codex=20=E6=A8=A1=E5=9E=8B?=
=?UTF-8?q?=E5=B9=B6=E4=BF=AE=E5=A4=8D=E5=BD=92=E4=B8=80=E5=8C=96=E5=85=9C?=
=?UTF-8?q?=E5=BA=95=E5=89=AF=E4=BD=9C=E7=94=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- backend: 删除 gpt-5 / 5.1 / 5.1-codex / 5.1-codex-max / 5.1-codex-mini / 5.2-codex / 5.4-nano 的内置映射与 DefaultModels 条目
- backend: normalizeCodexModel 默认兜底由 gpt-5.1 改为 gpt-5.4,gpt-5.3-codex-spark 独立保留映射
- backend: 修复 isOpenAIGPT54Model 与 shouldAutoInjectPromptCacheKeyForCompat 对 claude / gpt-4o 的误判(之前依赖 gpt-5.1 作为非 GPT 族的隐式 sentinel,改后需要显式前缀守卫)
- backend: 清理 billing_service 中已不可达的 fallback 价格与 switch 分支
- frontend: 从白名单、OpenCode 配置、预设映射中移除已下线模型
- 同步更新所有相关单测
Refs: #1758, parallels upstream #1759 but adds downstream guard fixes
---
backend/internal/pkg/openai/constants.go | 9 +-
backend/internal/service/billing_service.go | 51 ++--------
.../internal/service/billing_service_test.go | 29 +-----
.../service/openai_codex_transform.go | 76 +++------------
.../service/openai_codex_transform_test.go | 30 ++++--
.../service/openai_compat_prompt_cache_key.go | 10 +-
.../service/openai_model_mapping_test.go | 14 +--
frontend/src/components/keys/UseKeyModal.vue | 92 -------------------
.../keys/__tests__/UseKeyModal.spec.ts | 4 +-
.../__tests__/useModelWhitelist.spec.ts | 19 +++-
frontend/src/composables/useModelWhitelist.ts | 15 +--
11 files changed, 84 insertions(+), 265 deletions(-)
diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go
index 49e38bf8..980f058d 100644
--- a/backend/internal/pkg/openai/constants.go
+++ b/backend/internal/pkg/openai/constants.go
@@ -17,16 +17,9 @@ type Model struct {
var DefaultModels = []Model{
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
{ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"},
- {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
- {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
- {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
- {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
- {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
- {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
- {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
}
// DefaultModelIDs returns the default model ID list
@@ -39,7 +32,7 @@ func DefaultModelIDs() []string {
}
// DefaultTestModel default model for testing OpenAI accounts
-const DefaultTestModel = "gpt-5.1-codex"
+const DefaultTestModel = "gpt-5.4"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index c9f32b3b..a45203a3 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() {
SupportsCacheBreakdown: false,
}
- // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
- s.fallbackPrices["gpt-5.1"] = &ModelPricing{
- InputPricePerToken: 1.25e-6, // $1.25 per MTok
- InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
- OutputPricePerToken: 10e-6, // $10 per MTok
- OutputPricePerTokenPriority: 20e-6, // $20 per MTok
- CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
- CacheReadPricePerToken: 0.125e-6,
- CacheReadPricePerTokenPriority: 0.25e-6,
- SupportsCacheBreakdown: false,
- }
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
@@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 7.5e-8,
SupportsCacheBreakdown: false,
}
- s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
- InputPricePerToken: 2e-7,
- OutputPricePerToken: 1.25e-6,
- CacheReadPricePerToken: 2e-8,
- SupportsCacheBreakdown: false,
- }
// OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
@@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
- // Codex 族兜底统一按 GPT-5.1 Codex 价格计费
- s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
+ // Codex 族兜底统一按 GPT-5.3 Codex 价格计费
+ s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
InputPricePerTokenPriority: 3e-6, // $3 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
@@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.3e-6,
SupportsCacheBreakdown: false,
}
- s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
- InputPricePerToken: 1.75e-6,
- InputPricePerTokenPriority: 3.5e-6,
- OutputPricePerToken: 14e-6,
- OutputPricePerTokenPriority: 28e-6,
- CacheCreationPricePerToken: 1.75e-6,
- CacheReadPricePerToken: 0.175e-6,
- CacheReadPricePerTokenPriority: 0.35e-6,
- SupportsCacheBreakdown: false,
- }
- s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
switch normalized {
case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"]
- case "gpt-5.4-nano":
- return s.fallbackPrices["gpt-5.4-nano"]
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2":
return s.fallbackPrices["gpt-5.2"]
- case "gpt-5.2-codex":
- return s.fallbackPrices["gpt-5.2-codex"]
- case "gpt-5.3-codex":
+ case "gpt-5.3-codex", "gpt-5.3-codex-spark":
return s.fallbackPrices["gpt-5.3-codex"]
- case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
- return s.fallbackPrices["gpt-5.1-codex"]
- case "gpt-5.1":
- return s.fallbackPrices["gpt-5.1"]
}
}
@@ -667,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
}
func isOpenAIGPT54Model(model string) bool {
- normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
- return normalized == "gpt-5.4"
+ trimmed := strings.TrimSpace(strings.ToLower(model))
+ // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
+ // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
+ if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
+ return false
+ }
+ return normalizeCodexModel(trimmed) == "gpt-5.4"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index fc8361c7..222abd69 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -123,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
require.Contains(t, err.Error(), "pricing not found")
}
-func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
- svc := newTestBillingService()
-
- pricing, err := svc.GetModelPricing("gpt-5.1")
- require.NoError(t, err)
- require.NotNil(t, pricing)
- require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
-}
-
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService()
@@ -158,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
require.Zero(t, pricing.LongContextInputThreshold)
}
-func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) {
- svc := newTestBillingService()
-
- pricing, err := svc.GetModelPricing("gpt-5.4-nano")
- require.NoError(t, err)
- require.NotNil(t, pricing)
- require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12)
- require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12)
- require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12)
- require.Zero(t, pricing.LongContextInputThreshold)
-}
-
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService()
@@ -204,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
- {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7},
- {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
- {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
- {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
+ {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6},
+ {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6},
+ {name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6},
+ {name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
}
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index a266d6a0..457309d3 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -8,7 +8,6 @@ import (
var codexModelMap = map[string]string{
"gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
- "gpt-5.4-nano": "gpt-5.4-nano",
"gpt-5.4-none": "gpt-5.4",
"gpt-5.4-low": "gpt-5.4",
"gpt-5.4-medium": "gpt-5.4",
@@ -22,52 +21,21 @@ var codexModelMap = map[string]string{
"gpt-5.3-high": "gpt-5.3-codex",
"gpt-5.3-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-low": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-medium": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3-codex-low": "gpt-5.3-codex",
"gpt-5.3-codex-medium": "gpt-5.3-codex",
"gpt-5.3-codex-high": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
- "gpt-5.1-codex": "gpt-5.1-codex",
- "gpt-5.1-codex-low": "gpt-5.1-codex",
- "gpt-5.1-codex-medium": "gpt-5.1-codex",
- "gpt-5.1-codex-high": "gpt-5.1-codex",
- "gpt-5.1-codex-max": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2",
- "gpt-5.2-codex": "gpt-5.2-codex",
- "gpt-5.2-codex-low": "gpt-5.2-codex",
- "gpt-5.2-codex-medium": "gpt-5.2-codex",
- "gpt-5.2-codex-high": "gpt-5.2-codex",
- "gpt-5.2-codex-xhigh": "gpt-5.2-codex",
- "gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
- "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
- "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
- "gpt-5.1": "gpt-5.1",
- "gpt-5.1-none": "gpt-5.1",
- "gpt-5.1-low": "gpt-5.1",
- "gpt-5.1-medium": "gpt-5.1",
- "gpt-5.1-high": "gpt-5.1",
- "gpt-5.1-chat-latest": "gpt-5.1",
- "gpt-5-codex": "gpt-5.1-codex",
- "codex-mini-latest": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
- "gpt-5": "gpt-5.1",
- "gpt-5-mini": "gpt-5.1",
- "gpt-5-nano": "gpt-5.1",
}
type codexTransformResult struct {
@@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
func normalizeCodexModel(model string) string {
if model == "" {
- return "gpt-5.1"
+ return "gpt-5.4"
}
modelID := model
@@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string {
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
return "gpt-5.4-mini"
}
- if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") {
- return "gpt-5.4-nano"
- }
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
return "gpt-5.4"
}
- if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
- return "gpt-5.2-codex"
- }
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
+ if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
+ return "gpt-5.3-codex-spark"
+ }
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3-codex"
}
- if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
- return "gpt-5.1-codex-max"
- }
- if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
- return "gpt-5.1-codex-mini"
- }
- if strings.Contains(normalized, "codex-mini-latest") ||
- strings.Contains(normalized, "gpt-5-codex-mini") ||
- strings.Contains(normalized, "gpt 5 codex mini") {
- return "codex-mini-latest"
- }
- if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
- return "gpt-5.1-codex"
- }
- if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
- return "gpt-5.1"
- }
if strings.Contains(normalized, "codex") {
- return "gpt-5.1-codex"
+ return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
- return "gpt-5.1"
+ return "gpt-5.4"
}
- return "gpt-5.1"
+ return "gpt-5.4"
}
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 993ade07..22264f5e 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini",
- "gpt-5.4-nano": "gpt-5.4-nano",
- "gpt 5.4 nano": "gpt-5.4-nano",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt 5.3 codex spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt 5.3 codex spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt 5.3 codex": "gpt-5.3-codex",
}
@@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
}
}
+func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) {
+ cases := map[string]string{
+ "": "gpt-5.4",
+ "gpt-5": "gpt-5.4",
+ "gpt-5-mini": "gpt-5.4",
+ "gpt-5-nano": "gpt-5.4",
+ "gpt-5.1": "gpt-5.4",
+ "gpt-5.1-codex": "gpt-5.3-codex",
+ "gpt-5.1-codex-max": "gpt-5.3-codex",
+ "gpt-5.1-codex-mini": "gpt-5.3-codex",
+ "gpt-5.2-codex": "gpt-5.2",
+ "codex-mini-latest": "gpt-5.3-codex",
+ "gpt-5-codex": "gpt-5.3-codex",
+ }
+
+ for input, expected := range cases {
+ require.Equal(t, expected, normalizeCodexModel(input))
+ }
+}
+
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go
index 88e16a4d..fcd27f19 100644
--- a/backend/internal/service/openai_compat_prompt_cache_key.go
+++ b/backend/internal/service/openai_compat_prompt_cache_key.go
@@ -10,8 +10,14 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
- switch normalizeCodexModel(strings.TrimSpace(model)) {
- case "gpt-5.4", "gpt-5.3-codex":
+ trimmed := strings.TrimSpace(strings.ToLower(model))
+ // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel
+ // 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。
+ if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
+ return false
+ }
+ switch normalizeCodexModel(trimmed) {
+ case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
return true
default:
return false
diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go
index cda7e369..35e7c250 100644
--- a/backend/internal/service/openai_model_mapping_test.go
+++ b/backend/internal/service/openai_model_mapping_test.go
@@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
}
}
-func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
+func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) {
account := &Account{
Credentials: map[string]any{},
}
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
- if withoutDefault != "gpt-5.1" {
- t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
+ if withoutDefault != "gpt-5.4" {
+ t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
}
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
@@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3": "gpt-5.3-codex",
}
@@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
name: "oauth keeps codex normalization behavior",
account: &Account{Type: AccountTypeOAuth},
model: "gemini-3-flash-preview",
- want: "gpt-5.1",
+ want: "gpt-5.4",
},
{
name: "apikey preserves custom compatible model",
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index 7770e658..b3679107 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -617,66 +617,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
const openaiModels = {
- 'gpt-5-codex': {
- name: 'GPT-5 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex': {
- name: 'GPT-5.1 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex-max': {
- name: 'GPT-5.1 Codex Max',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex-mini': {
- name: 'GPT-5.1 Codex Mini',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
'gpt-5.2': {
name: 'GPT-5.2',
limit: {
@@ -725,22 +665,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
- 'gpt-5.4-nano': {
- name: 'GPT-5.4 Nano',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {},
- xhigh: {}
- }
- },
'gpt-5.3-codex-spark': {
name: 'GPT-5.3 Codex Spark',
limit: {
@@ -773,22 +697,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
- 'gpt-5.2-codex': {
- name: 'GPT-5.2 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {},
- xhigh: {}
- }
- },
'codex-mini-latest': {
name: 'Codex Mini',
limit: {
diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
index 98b5dede..f7db586a 100644
--- a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
+++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
@@ -17,7 +17,7 @@ vi.mock('@/composables/useClipboard', () => ({
import UseKeyModal from '../UseKeyModal.vue'
describe('UseKeyModal', () => {
- it('renders updated GPT-5.4 mini/nano names in OpenCode config', async () => {
+ it('renders GPT-5.4 mini entry in OpenCode config', async () => {
const wrapper = mount(UseKeyModal, {
props: {
show: true,
@@ -48,6 +48,6 @@ describe('UseKeyModal', () => {
const codeBlock = wrapper.find('pre code')
expect(codeBlock.exists()).toBe(true)
expect(codeBlock.text()).toContain('"name": "GPT-5.4 Mini"')
- expect(codeBlock.text()).toContain('"name": "GPT-5.4 Nano"')
+ expect(codeBlock.text()).not.toContain('"name": "GPT-5.4 Nano"')
})
})
diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts
index 4061be4d..d35e3b12 100644
--- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts
+++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts
@@ -12,10 +12,20 @@ describe('useModelWhitelist', () => {
expect(models).toContain('gpt-5.4')
expect(models).toContain('gpt-5.4-mini')
- expect(models).toContain('gpt-5.4-nano')
expect(models).toContain('gpt-5.4-2026-03-05')
})
+ it('openai 模型列表不再暴露已下线的 ChatGPT 登录 Codex 模型', () => {
+ const models = getModelsByPlatform('openai')
+
+ expect(models).not.toContain('gpt-5')
+ expect(models).not.toContain('gpt-5.1')
+ expect(models).not.toContain('gpt-5.1-codex')
+ expect(models).not.toContain('gpt-5.1-codex-max')
+ expect(models).not.toContain('gpt-5.1-codex-mini')
+ expect(models).not.toContain('gpt-5.2-codex')
+ })
+
it('antigravity 模型列表包含图片模型兼容项', () => {
const models = getModelsByPlatform('antigravity')
@@ -55,12 +65,11 @@ describe('useModelWhitelist', () => {
})
})
- it('whitelist keeps GPT-5.4 mini and nano exact mappings', () => {
- const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini', 'gpt-5.4-nano'], [])
+ it('whitelist keeps GPT-5.4 mini exact mappings', () => {
+ const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini'], [])
expect(mapping).toEqual({
- 'gpt-5.4-mini': 'gpt-5.4-mini',
- 'gpt-5.4-nano': 'gpt-5.4-nano'
+ 'gpt-5.4-mini': 'gpt-5.4-mini'
})
})
})
diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts
index a282ae7d..ddd7af48 100644
--- a/frontend/src/composables/useModelWhitelist.ts
+++ b/frontend/src/composables/useModelWhitelist.ts
@@ -13,19 +13,11 @@ const openaiModels = [
'o1', 'o1-preview', 'o1-mini', 'o1-pro',
'o3', 'o3-mini', 'o3-pro',
'o4-mini',
- // GPT-5 系列(同步后端定价文件)
- 'gpt-5', 'gpt-5-2025-08-07', 'gpt-5-chat', 'gpt-5-chat-latest',
- 'gpt-5-codex', 'gpt-5.3-codex-spark', 'gpt-5-pro', 'gpt-5-pro-2025-10-06',
- 'gpt-5-mini', 'gpt-5-mini-2025-08-07',
- 'gpt-5-nano', 'gpt-5-nano-2025-08-07',
- // GPT-5.1 系列
- 'gpt-5.1', 'gpt-5.1-2025-11-13', 'gpt-5.1-chat-latest',
- 'gpt-5.1-codex', 'gpt-5.1-codex-max', 'gpt-5.1-codex-mini',
// GPT-5.2 系列
'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest',
- 'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
+ 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
// GPT-5.4 系列
- 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-nano', 'gpt-5.4-2026-03-05',
+ 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05',
// GPT-5.3 系列
'gpt-5.3-codex', 'gpt-5.3-codex-spark',
'chatgpt-4o-latest',
@@ -264,12 +256,9 @@ const openaiPresetMappings = [
{ label: 'GPT-4.1', from: 'gpt-4.1', to: 'gpt-4.1', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
{ label: 'o1', from: 'o1', to: 'o1', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'o3', from: 'o3', to: 'o3', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
- { label: 'GPT-5', from: 'gpt-5', to: 'gpt-5', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' },
{ label: 'GPT-5.3 Codex Spark', from: 'gpt-5.3-codex-spark', to: 'gpt-5.3-codex-spark', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' },
- { label: 'GPT-5.1', from: 'gpt-5.1', to: 'gpt-5.1', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' },
{ label: 'GPT-5.2', from: 'gpt-5.2', to: 'gpt-5.2', color: 'bg-red-100 text-red-700 hover:bg-red-200 dark:bg-red-900/30 dark:text-red-400' },
{ label: 'GPT-5.4', from: 'gpt-5.4', to: 'gpt-5.4', color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400' },
- { label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Haiku→5.4', from: 'claude-haiku-4-5-20251001', to: 'gpt-5.4', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
{ label: 'Opus→5.4', from: 'claude-opus-4-6', to: 'gpt-5.4', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Sonnet→5.4', from: 'claude-sonnet-4-6', to: 'gpt-5.4', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' }
--
GitLab
From 3bd3027251dd5e008e60ba7299ef63dbc8f8a32c Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 22:05:33 +0800
Subject: [PATCH 062/261] feat: expose auth identity migration reports
---
.../admin/admin_basic_handlers_test.go | 12 ++
.../handler/admin/admin_service_stub_test.go | 34 +++++
.../internal/handler/admin/user_handler.go | 25 ++++
backend/internal/server/routes/admin.go | 2 +
backend/internal/service/admin_service.go | 135 ++++++++++++++++++
..._service_identity_migration_report_test.go | 92 ++++++++++++
6 files changed, 300 insertions(+)
create mode 100644 backend/internal/service/admin_service_identity_migration_report_test.go
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index cba3ae21..ff9eec7e 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -22,6 +22,8 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
+ router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
+ router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
@@ -70,6 +72,16 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports/summary", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports?report_type=oidc_synthetic_email_requires_manual_recovery", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 6d1ef1b6..681c25c6 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -17,6 +17,7 @@ type stubAdminService struct {
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
+ migrationReports []service.AuthIdentityMigrationReport
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
@@ -123,6 +124,15 @@ func newStubAdminService() *stubAdminService {
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
+ migrationReports: []service.AuthIdentityMigrationReport{
+ {
+ ID: 1,
+ ReportType: "oidc_synthetic_email_requires_manual_recovery",
+ ReportKey: "u-1",
+ Details: map[string]any{"user_id": 1},
+ CreatedAt: now,
+ },
+ },
}
}
@@ -167,6 +177,30 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
+func (s *stubAdminService) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]service.AuthIdentityMigrationReport, int64, error) {
+ if reportType == "" {
+ return s.migrationReports, int64(len(s.migrationReports)), nil
+ }
+ filtered := make([]service.AuthIdentityMigrationReport, 0, len(s.migrationReports))
+ for _, report := range s.migrationReports {
+ if strings.EqualFold(report.ReportType, reportType) {
+ filtered = append(filtered, report)
+ }
+ }
+ return filtered, int64(len(filtered)), nil
+}
+
+func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*service.AuthIdentityMigrationReportSummary, error) {
+ summary := &service.AuthIdentityMigrationReportSummary{
+ ByType: map[string]int64{},
+ }
+ for _, report := range s.migrationReports {
+ summary.Total++
+ summary.ByType[report.ReportType]++
+ }
+ return summary, nil
+}
+
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 1453bd07..ee3fbb1e 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -172,6 +172,31 @@ func (h *UserHandler) GetByID(c *gin.Context) {
response.Success(c, dto.UserFromServiceAdmin(user))
}
+// GetAuthIdentityMigrationReportSummary returns aggregate migration report counts.
+// GET /api/v1/admin/users/auth-identity-migration-reports/summary
+func (h *UserHandler) GetAuthIdentityMigrationReportSummary(c *gin.Context) {
+ summary, err := h.adminService.GetAuthIdentityMigrationReportSummary(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, summary)
+}
+
+// ListAuthIdentityMigrationReports returns paginated auth identity migration reports.
+// GET /api/v1/admin/users/auth-identity-migration-reports
+func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ reportType := strings.TrimSpace(c.Query("report_type"))
+
+ reports, total, err := h.adminService.ListAuthIdentityMigrationReports(c.Request.Context(), reportType, page, pageSize)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Paginated(c, reports, total, page, pageSize)
+}
+
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 9af0fd8e..0b5aaf09 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -210,6 +210,8 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users := admin.Group("/users")
{
+ users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary)
+ users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7c26a47c..972681a5 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -2,6 +2,8 @@ package service
import (
"context"
+ "database/sql"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -16,6 +18,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/util/httputil"
+
+ entsql "entgo.io/ent/dialect/sql"
)
// AdminService interface defines admin management operations
@@ -33,6 +37,8 @@ type AdminService interface {
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
+ ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
+ GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -127,6 +133,19 @@ type UpdateUserInput struct {
GroupRates map[int64]*float64
}
+type AuthIdentityMigrationReport struct {
+ ID int64 `json:"id"`
+ ReportType string `json:"report_type"`
+ ReportKey string `json:"report_key"`
+ Details map[string]any `json:"details"`
+ CreatedAt time.Time `json:"created_at"`
+}
+
+type AuthIdentityMigrationReportSummary struct {
+ Total int64 `json:"total"`
+ ByType map[string]int64 `json:"by_type"`
+}
+
type CreateGroupInput struct {
Name string
Description string
@@ -788,6 +807,122 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
return codes, result.Total, totalRecharged, nil
}
+func (s *adminServiceImpl) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) {
+ db, err := s.adminSQLDB()
+ if err != nil {
+ return nil, 0, err
+ }
+
+ reportType = strings.TrimSpace(reportType)
+ if page <= 0 {
+ page = 1
+ }
+ if pageSize <= 0 {
+ pageSize = 20
+ }
+ offset := (page - 1) * pageSize
+
+ var total int64
+ if err := db.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE ($1 = '' OR report_type = $1)`,
+ reportType,
+ ).Scan(&total); err != nil {
+ return nil, 0, err
+ }
+
+ rows, err := db.QueryContext(ctx, `
+SELECT id, report_type, report_key, details, created_at
+FROM auth_identity_migration_reports
+WHERE ($1 = '' OR report_type = $1)
+ORDER BY created_at DESC, id DESC
+LIMIT $2 OFFSET $3`,
+ reportType,
+ pageSize,
+ offset,
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ reports := make([]AuthIdentityMigrationReport, 0)
+ for rows.Next() {
+ report, scanErr := scanAuthIdentityMigrationReport(rows)
+ if scanErr != nil {
+ return nil, 0, scanErr
+ }
+ reports = append(reports, report)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, 0, err
+ }
+ return reports, total, nil
+}
+
+func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
+ db, err := s.adminSQLDB()
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := db.QueryContext(ctx, `
+SELECT report_type, COUNT(*)
+FROM auth_identity_migration_reports
+GROUP BY report_type
+ORDER BY report_type ASC`)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ summary := &AuthIdentityMigrationReportSummary{
+ ByType: make(map[string]int64),
+ }
+ for rows.Next() {
+ var reportType string
+ var count int64
+ if err := rows.Scan(&reportType, &count); err != nil {
+ return nil, err
+ }
+ summary.ByType[reportType] = count
+ summary.Total += count
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return summary, nil
+}
+
+func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
+ if s == nil || s.entClient == nil {
+ return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready")
+ }
+ driver, ok := s.entClient.Driver().(*entsql.Driver)
+ if !ok || driver.DB() == nil {
+ return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready")
+ }
+ return driver.DB(), nil
+}
+
+func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
+ var (
+ report AuthIdentityMigrationReport
+ details []byte
+ )
+ if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
+ return AuthIdentityMigrationReport{}, err
+ }
+ report.Details = map[string]any{}
+ if len(details) > 0 {
+ if err := json.Unmarshal(details, &report.Details); err != nil {
+ return AuthIdentityMigrationReport{}, err
+ }
+ }
+ return report, nil
+}
+
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go
new file mode 100644
index 00000000..75ca3e5a
--- /dev/null
+++ b/backend/internal/service/admin_service_identity_migration_report_test.go
@@ -0,0 +1,92 @@
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:admin_service_migration_reports?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ _, err = db.Exec(`CREATE TABLE auth_identity_migration_reports (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ report_type TEXT NOT NULL,
+ report_key TEXT NOT NULL,
+ details TEXT NOT NULL DEFAULT '{}',
+ created_at DATETIME NOT NULL
+ )`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+func TestAdminServiceListAuthIdentityMigrationReports(t *testing.T) {
+ client := newAdminServiceMigrationReportTestClient(t)
+ driver, ok := client.Driver().(*entsql.Driver)
+ require.True(t, ok)
+
+ now := time.Now().UTC()
+ _, err := driver.DB().ExecContext(context.Background(), `
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
+VALUES
+ ($1, $2, $3, $4),
+ ($5, $6, $7, $8)`,
+ "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now,
+ "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute),
+ )
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{entClient: client}
+ reports, total, err := svc.ListAuthIdentityMigrationReports(context.Background(), "oidc_synthetic_email_requires_manual_recovery", 1, 20)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), total)
+ require.Len(t, reports, 1)
+ require.Equal(t, "oidc_synthetic_email_requires_manual_recovery", reports[0].ReportType)
+ require.Equal(t, float64(1), reports[0].Details["user_id"])
+}
+
+func TestAdminServiceGetAuthIdentityMigrationReportSummary(t *testing.T) {
+ client := newAdminServiceMigrationReportTestClient(t)
+ driver, ok := client.Driver().(*entsql.Driver)
+ require.True(t, ok)
+
+ now := time.Now().UTC()
+ _, err := driver.DB().ExecContext(context.Background(), `
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
+VALUES
+ ($1, $2, $3, $4),
+ ($5, $6, $7, $8),
+ ($9, $10, $11, $12)`,
+ "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now,
+ "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute),
+ "wechat_provider_key_conflict", "u-3", `{"user_id":3}`, now.Add(-2*time.Minute),
+ )
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{entClient: client}
+ summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, int64(3), summary.Total)
+ require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
+ require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"])
+}
--
GitLab
From 452e55a53c7a22c3e4a0da2421d19cfb6819de96 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 22:22:14 +0800
Subject: [PATCH 063/261] feat: add admin auth identity repair binding
---
.../admin/admin_basic_handlers_test.go | 48 +++-
.../handler/admin/admin_service_stub_test.go | 48 ++++
.../internal/handler/admin/user_handler.go | 55 ++++
backend/internal/server/routes/admin.go | 1 +
backend/internal/service/admin_service.go | 262 ++++++++++++++++++
...dmin_service_auth_identity_binding_test.go | 215 ++++++++++++++
6 files changed, 628 insertions(+), 1 deletion(-)
create mode 100644 backend/internal/service/admin_service_auth_identity_binding_test.go
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index ff9eec7e..57620005 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -25,6 +25,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
+ router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
@@ -87,8 +88,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ bindBody := map[string]any{
+ "provider_type": "wechat",
+ "provider_key": "wechat-main",
+ "provider_subject": "union-123",
+ "metadata": map[string]any{"source": "admin-repair"},
+ "channel": map[string]any{
+ "channel": "open",
+ "channel_app_id": "wx-open",
+ "channel_subject": "openid-123",
+ },
+ }
+ body, _ := json.Marshal(bindBody)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
- body, _ := json.Marshal(createBody)
+ body, _ = json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
@@ -125,6 +144,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
}
+func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
+ router, adminSvc := setupAdminRouter()
+
+ body, err := json.Marshal(map[string]any{
+ "provider_type": "oidc",
+ "provider_key": "https://issuer.example",
+ "provider_subject": "subject-123",
+ "issuer": "https://issuer.example",
+ "metadata": map[string]any{"report_id": 12},
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
+ require.NotNil(t, adminSvc.boundAuthIdentity)
+ require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
+ require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
+ require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
+ require.Nil(t, adminSvc.boundAuthIdentity.Channel)
+ require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
+}
+
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 681c25c6..c8c7a247 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -18,6 +18,8 @@ type stubAdminService struct {
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
migrationReports []service.AuthIdentityMigrationReport
+ boundAuthIdentity *service.AdminBindAuthIdentityInput
+ boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
@@ -201,6 +203,52 @@ func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Con
return summary, nil
}
+func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
+ s.boundAuthIdentityFor = userID
+ copied := input
+ if input.Metadata != nil {
+ copied.Metadata = map[string]any{}
+ for key, value := range input.Metadata {
+ copied.Metadata[key] = value
+ }
+ }
+ if input.Channel != nil {
+ channel := *input.Channel
+ if input.Channel.Metadata != nil {
+ channel.Metadata = map[string]any{}
+ for key, value := range input.Channel.Metadata {
+ channel.Metadata[key] = value
+ }
+ }
+ copied.Channel = &channel
+ }
+ s.boundAuthIdentity = &copied
+
+ now := time.Now().UTC()
+ result := &service.AdminBoundAuthIdentity{
+ UserID: userID,
+ ProviderType: input.ProviderType,
+ ProviderKey: input.ProviderKey,
+ ProviderSubject: input.ProviderSubject,
+ VerifiedAt: &now,
+ Issuer: input.Issuer,
+ Metadata: input.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if input.Channel != nil {
+ result.Channel = &service.AdminBoundAuthIdentityChannel{
+ Channel: input.Channel.Channel,
+ ChannelAppID: input.Channel.ChannelAppID,
+ ChannelSubject: input.Channel.ChannelSubject,
+ Metadata: input.Channel.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ }
+ return result, nil
+}
+
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index ee3fbb1e..321214af 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -66,6 +66,22 @@ type UpdateBalanceRequest struct {
Notes string `json:"notes"`
}
+type BindUserAuthIdentityRequest struct {
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ Issuer *string `json:"issuer"`
+ Metadata map[string]any `json:"metadata"`
+ Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
+}
+
+type BindUserAuthIdentityChannelRequest struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+}
+
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
@@ -197,6 +213,45 @@ func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) {
response.Paginated(c, reports, total, page, pageSize)
}
+// BindAuthIdentity manually binds a canonical auth identity to a user.
+// POST /api/v1/admin/users/:id/auth-identities
+func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ var req BindUserAuthIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ input := service.AdminBindAuthIdentityInput{
+ ProviderType: req.ProviderType,
+ ProviderKey: req.ProviderKey,
+ ProviderSubject: req.ProviderSubject,
+ Issuer: req.Issuer,
+ Metadata: req.Metadata,
+ }
+ if req.Channel != nil {
+ input.Channel = &service.AdminBindAuthIdentityChannelInput{
+ Channel: req.Channel.Channel,
+ ChannelAppID: req.Channel.ChannelAppID,
+ ChannelSubject: req.Channel.ChannelSubject,
+ Metadata: req.Channel.Metadata,
+ }
+ }
+
+ result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
+
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 0b5aaf09..c78fba33 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -214,6 +214,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
+ users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 972681a5..9ff26861 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -13,6 +13,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -39,6 +41,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
+ BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -146,6 +149,44 @@ type AuthIdentityMigrationReportSummary struct {
ByType map[string]int64 `json:"by_type"`
}
+type AdminBindAuthIdentityInput struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+ Issuer *string
+ Metadata map[string]any
+ Channel *AdminBindAuthIdentityChannelInput
+}
+
+type AdminBindAuthIdentityChannelInput struct {
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+ Metadata map[string]any
+}
+
+type AdminBoundAuthIdentity struct {
+ UserID int64 `json:"user_id"`
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ Issuer *string `json:"issuer,omitempty"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
+}
+
+type AdminBoundAuthIdentityChannel struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
type CreateGroupInput struct {
Name string
Description string
@@ -895,6 +936,143 @@ ORDER BY report_type ASC`)
return summary, nil
}
+func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
+ }
+ if s == nil || s.entClient == nil || s.userRepo == nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ return nil, err
+ }
+
+ providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
+ providerKey := strings.TrimSpace(input.ProviderKey)
+ providerSubject := strings.TrimSpace(input.ProviderSubject)
+ if providerType == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
+ }
+ if providerKey == "" || providerSubject == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
+ }
+
+ var issuer *string
+ if input.Issuer != nil {
+ trimmed := strings.TrimSpace(*input.Issuer)
+ if trimmed != "" {
+ issuer = &trimmed
+ }
+ }
+
+ channelInput := normalizeAdminBindChannelInput(input.Channel)
+ if input.Channel != nil && channelInput == nil {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
+ }
+
+ verifiedAt := time.Now().UTC()
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ identity, err := tx.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyEQ(providerKey),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if identity != nil && identity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+
+ if identity == nil {
+ create := tx.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetVerifiedAt(verifiedAt)
+ if issuer != nil {
+ create = create.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
+ if issuer != nil {
+ update = update.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if channelInput != nil {
+ channel, err = tx.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyEQ(providerKey),
+ authidentitychannel.ChannelEQ(channelInput.Channel),
+ authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
+ ).
+ WithIdentity().
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ if channel == nil {
+ create := tx.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetChannel(channelInput.Channel).
+ SetChannelAppID(channelInput.ChannelAppID).
+ SetChannelSubject(channelInput.ChannelSubject)
+ if channelInput.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
+ if channelInput.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
+ }
+ return buildAdminBoundAuthIdentity(identity, channel), nil
+}
+
func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
if s == nil || s.entClient == nil {
return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready")
@@ -906,6 +1084,90 @@ func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
return driver.DB(), nil
}
+func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
+ if input == nil {
+ return nil
+ }
+ channel := &AdminBindAuthIdentityChannelInput{
+ Channel: strings.TrimSpace(input.Channel),
+ ChannelAppID: strings.TrimSpace(input.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(input.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
+ }
+ if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
+ return nil
+ }
+ return channel
+}
+
+func normalizeAdminAuthIdentityProviderType(input string) string {
+ switch strings.ToLower(strings.TrimSpace(input)) {
+ case "email":
+ return "email"
+ case "linuxdo":
+ return "linuxdo"
+ case "oidc":
+ return "oidc"
+ case "wechat":
+ return "wechat"
+ default:
+ return ""
+ }
+}
+
+func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
+ if identity == nil {
+ return nil
+ }
+ result := &AdminBoundAuthIdentity{
+ UserID: identity.UserID,
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ }
+ if channel != nil {
+ result.Channel = &AdminBoundAuthIdentityChannel{
+ Channel: strings.TrimSpace(channel.Channel),
+ ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
+ CreatedAt: channel.CreatedAt,
+ UpdatedAt: channel.UpdatedAt,
+ }
+ }
+ return result
+}
+
+func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
+ if input == nil {
+ return nil
+ }
+ if len(input) == 0 {
+ return map[string]any{}
+ }
+ data, err := json.Marshal(input)
+ if err != nil {
+ out := make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ return out
+ }
+ var out map[string]any
+ if err := json.Unmarshal(data, &out); err != nil {
+ out = make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ }
+ return out
+}
+
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
var (
report AuthIdentityMigrationReport
diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go
new file mode 100644
index 00000000..f8ce3935
--- /dev/null
+++ b/backend/internal/service/admin_service_auth_identity_binding_test.go
@@ -0,0 +1,215 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("bind-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-123",
+ Metadata: map[string]any{"scene": "migration"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, user.ID, result.UserID)
+ require.Equal(t, "wechat", result.ProviderType)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.VerifiedAt)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+ require.NotNil(t, identity.VerifiedAt)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, identity.ID, channel.IdentityID)
+ require.Equal(t, "migration", channel.Metadata["scene"])
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ target, err := client.User.Create().
+ SetEmail("target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("subject-1").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ })
+ require.Error(t, err)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
+}
+
+func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("same-user@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "first"},
+ })
+ require.NoError(t, err)
+
+ second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "second"},
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.UserID, second.UserID)
+ require.Equal(t, "second", second.Metadata["source"])
+
+ identities, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("subject-2"),
+ ).
+ All(ctx)
+ require.NoError(t, err)
+ require.Len(t, identities, 1)
+ require.Equal(t, "second", identities[0].Metadata["source"])
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("invalid-provider@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "github",
+ ProviderKey: "github-main",
+ ProviderSubject: "subject-3",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
+}
--
GitLab
From 724f8e89a1d8a61b9847883dea8573894b6b5c02 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 22:29:21 +0800
Subject: [PATCH 064/261] feat: resolve auth identity migration reports
---
.../admin/admin_basic_handlers_test.go | 14 ++-
.../handler/admin/admin_service_stub_test.go | 14 +++
.../internal/handler/admin/user_handler.go | 39 ++++++
backend/internal/server/routes/admin.go | 1 +
backend/internal/service/admin_service.go | 111 ++++++++++++++++--
..._service_identity_migration_report_test.go | 34 +++++-
...h_identity_migration_report_resolution.sql | 11 ++
7 files changed, 209 insertions(+), 15 deletions(-)
create mode 100644 backend/migrations/114_auth_identity_migration_report_resolution.sql
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index 57620005..207931d9 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"testing"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -24,6 +25,10 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
+ router.POST("/api/v1/admin/users/auth-identity-migration-reports/:id/resolve", func(c *gin.Context) {
+ c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 99})
+ userHandler.ResolveAuthIdentityMigrationReport(c)
+ })
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
@@ -83,6 +88,13 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ body, _ := json.Marshal(map[string]any{"resolution_note": "resolved by manual bind"})
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/auth-identity-migration-reports/1/resolve", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
@@ -99,7 +111,7 @@ func TestUserHandlerEndpoints(t *testing.T) {
"channel_subject": "openid-123",
},
}
- body, _ := json.Marshal(bindBody)
+ body, _ = json.Marshal(bindBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index c8c7a247..8ecadfdf 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -249,6 +249,20 @@ func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int6
return result, nil
}
+func (s *stubAdminService) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*service.AuthIdentityMigrationReport, error) {
+ now := time.Now().UTC()
+ for i := range s.migrationReports {
+ if s.migrationReports[i].ID != reportID {
+ continue
+ }
+ s.migrationReports[i].ResolvedAt = &now
+ s.migrationReports[i].ResolvedByUserID = &resolvedByUserID
+ s.migrationReports[i].ResolutionNote = resolutionNote
+ return &s.migrationReports[i], nil
+ }
+ return nil, nil
+}
+
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 321214af..e582e322 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -7,6 +7,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -82,6 +83,10 @@ type BindUserAuthIdentityChannelRequest struct {
Metadata map[string]any `json:"metadata"`
}
+type ResolveAuthIdentityMigrationReportRequest struct {
+ ResolutionNote string `json:"resolution_note"`
+}
+
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
@@ -252,6 +257,40 @@ func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
response.Success(c, result)
}
+// ResolveAuthIdentityMigrationReport marks a migration report as resolved.
+// POST /api/v1/admin/users/auth-identity-migration-reports/:id/resolve
+func (h *UserHandler) ResolveAuthIdentityMigrationReport(c *gin.Context) {
+ reportID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid report ID")
+ return
+ }
+
+ subject, ok := servermiddleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Authentication required")
+ return
+ }
+
+ var req ResolveAuthIdentityMigrationReportRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ report, err := h.adminService.ResolveAuthIdentityMigrationReport(
+ c.Request.Context(),
+ reportID,
+ subject.UserID,
+ req.ResolutionNote,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, report)
+}
+
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index c78fba33..e5c0eac1 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -212,6 +212,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary)
users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
+ users.POST("/auth-identity-migration-reports/:id/resolve", h.Admin.User.ResolveAuthIdentityMigrationReport)
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 9ff26861..3490374e 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -42,6 +42,7 @@ type AdminService interface {
ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
+ ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -137,16 +138,21 @@ type UpdateUserInput struct {
}
type AuthIdentityMigrationReport struct {
- ID int64 `json:"id"`
- ReportType string `json:"report_type"`
- ReportKey string `json:"report_key"`
- Details map[string]any `json:"details"`
- CreatedAt time.Time `json:"created_at"`
+ ID int64 `json:"id"`
+ ReportType string `json:"report_type"`
+ ReportKey string `json:"report_key"`
+ Details map[string]any `json:"details"`
+ CreatedAt time.Time `json:"created_at"`
+ ResolvedAt *time.Time `json:"resolved_at,omitempty"`
+ ResolvedByUserID *int64 `json:"resolved_by_user_id,omitempty"`
+ ResolutionNote string `json:"resolution_note,omitempty"`
}
type AuthIdentityMigrationReportSummary struct {
- Total int64 `json:"total"`
- ByType map[string]int64 `json:"by_type"`
+ Total int64 `json:"total"`
+ OpenTotal int64 `json:"open_total"`
+ ResolvedTotal int64 `json:"resolved_total"`
+ ByType map[string]int64 `json:"by_type"`
}
type AdminBindAuthIdentityInput struct {
@@ -874,7 +880,7 @@ WHERE ($1 = '' OR report_type = $1)`,
}
rows, err := db.QueryContext(ctx, `
-SELECT id, report_type, report_key, details, created_at
+SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note
FROM auth_identity_migration_reports
WHERE ($1 = '' OR report_type = $1)
ORDER BY created_at DESC, id DESC
@@ -909,7 +915,11 @@ func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Con
}
rows, err := db.QueryContext(ctx, `
-SELECT report_type, COUNT(*)
+SELECT
+ report_type,
+ COUNT(*),
+ SUM(CASE WHEN resolved_at IS NULL THEN 1 ELSE 0 END),
+ SUM(CASE WHEN resolved_at IS NOT NULL THEN 1 ELSE 0 END)
FROM auth_identity_migration_reports
GROUP BY report_type
ORDER BY report_type ASC`)
@@ -924,11 +934,15 @@ ORDER BY report_type ASC`)
for rows.Next() {
var reportType string
var count int64
- if err := rows.Scan(&reportType, &count); err != nil {
+ var openCount int64
+ var resolvedCount int64
+ if err := rows.Scan(&reportType, &count, &openCount, &resolvedCount); err != nil {
return nil, err
}
summary.ByType[reportType] = count
summary.Total += count
+ summary.OpenTotal += openCount
+ summary.ResolvedTotal += resolvedCount
}
if err := rows.Err(); err != nil {
return nil, err
@@ -936,6 +950,56 @@ ORDER BY report_type ASC`)
return summary, nil
}
+func (s *adminServiceImpl) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error) {
+ if reportID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "report id must be greater than 0")
+ }
+ if resolvedByUserID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "resolved_by_user_id must be greater than 0")
+ }
+
+ db, err := s.adminSQLDB()
+ if err != nil {
+ return nil, err
+ }
+
+ now := time.Now().UTC()
+ result, err := db.ExecContext(ctx, `
+UPDATE auth_identity_migration_reports
+SET
+ resolved_at = COALESCE(resolved_at, $2),
+ resolved_by_user_id = COALESCE(resolved_by_user_id, $3),
+ resolution_note = $4
+WHERE id = $1`,
+ reportID,
+ now,
+ resolvedByUserID,
+ strings.TrimSpace(resolutionNote),
+ )
+ if err != nil {
+ return nil, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return nil, err
+ }
+ if affected == 0 {
+ return nil, infraerrors.NotFound("AUTH_IDENTITY_MIGRATION_REPORT_NOT_FOUND", "auth identity migration report not found")
+ }
+
+ row := db.QueryRowContext(ctx, `
+SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note
+FROM auth_identity_migration_reports
+WHERE id = $1`,
+ reportID,
+ )
+ report, err := scanAuthIdentityMigrationReport(row)
+ if err != nil {
+ return nil, err
+ }
+ return &report, nil
+}
+
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
@@ -1170,10 +1234,22 @@ func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
var (
- report AuthIdentityMigrationReport
- details []byte
+ report AuthIdentityMigrationReport
+ details []byte
+ resolvedAt sql.NullTime
+ resolvedByUserID sql.NullInt64
+ resolutionNote sql.NullString
)
- if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
+ if err := scanner.Scan(
+ &report.ID,
+ &report.ReportType,
+ &report.ReportKey,
+ &details,
+ &report.CreatedAt,
+ &resolvedAt,
+ &resolvedByUserID,
+ &resolutionNote,
+ ); err != nil {
return AuthIdentityMigrationReport{}, err
}
report.Details = map[string]any{}
@@ -1182,6 +1258,15 @@ func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error
return AuthIdentityMigrationReport{}, err
}
}
+ if resolvedAt.Valid {
+ report.ResolvedAt = &resolvedAt.Time
+ }
+ if resolvedByUserID.Valid {
+ report.ResolvedByUserID = &resolvedByUserID.Int64
+ }
+ if resolutionNote.Valid {
+ report.ResolutionNote = resolutionNote.String
+ }
return report, nil
}
diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go
index 75ca3e5a..6975604b 100644
--- a/backend/internal/service/admin_service_identity_migration_report_test.go
+++ b/backend/internal/service/admin_service_identity_migration_report_test.go
@@ -30,7 +30,10 @@ func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client {
report_type TEXT NOT NULL,
report_key TEXT NOT NULL,
details TEXT NOT NULL DEFAULT '{}',
- created_at DATETIME NOT NULL
+ created_at DATETIME NOT NULL,
+ resolved_at DATETIME NULL,
+ resolved_by_user_id INTEGER NULL,
+ resolution_note TEXT NOT NULL DEFAULT ''
)`)
require.NoError(t, err)
@@ -87,6 +90,35 @@ VALUES
summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background())
require.NoError(t, err)
require.Equal(t, int64(3), summary.Total)
+ require.Equal(t, int64(3), summary.OpenTotal)
+ require.Zero(t, summary.ResolvedTotal)
require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"])
}
+
+func TestAdminServiceResolveAuthIdentityMigrationReport(t *testing.T) {
+ client := newAdminServiceMigrationReportTestClient(t)
+ driver, ok := client.Driver().(*entsql.Driver)
+ require.True(t, ok)
+
+ now := time.Now().UTC()
+ _, err := driver.DB().ExecContext(context.Background(), `
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
+VALUES ($1, $2, $3, $4)`,
+ "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now,
+ )
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{entClient: client}
+ report, err := svc.ResolveAuthIdentityMigrationReport(context.Background(), 1, 99, "resolved by admin binding")
+ require.NoError(t, err)
+ require.NotNil(t, report.ResolvedAt)
+ require.NotNil(t, report.ResolvedByUserID)
+ require.Equal(t, int64(99), *report.ResolvedByUserID)
+ require.Equal(t, "resolved by admin binding", report.ResolutionNote)
+
+ summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background())
+ require.NoError(t, err)
+ require.Zero(t, summary.OpenTotal)
+ require.Equal(t, int64(1), summary.ResolvedTotal)
+}
diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql
new file mode 100644
index 00000000..f84bf822
--- /dev/null
+++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql
@@ -0,0 +1,11 @@
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL;
+
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL;
+
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT '';
+
+CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at
+ ON auth_identity_migration_reports (resolved_at);
--
GitLab
From bffcc2042eeeb61eb0e5c9fc0c7cd4d40ff89f96 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Mon, 20 Apr 2026 22:37:25 +0800
Subject: [PATCH 065/261] fix: complete oidc pending auth callback flows
---
frontend/src/views/auth/OidcCallbackView.vue | 420 ++++++++++++++----
.../auth/__tests__/OidcCallbackView.spec.ts | 187 +++++---
2 files changed, 457 insertions(+), 150 deletions(-)
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index deb40b20..de3f2e40 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -18,9 +18,10 @@
@@ -108,46 +109,131 @@
-
+
- Continue with email to finish setting up your {{ providerName }} sign-in.
+ Review the {{ providerName }} profile details before continuing.
-
+
+ {{ isSubmitting ? t('common.processing') : 'Continue' }}
+
+
+
+
+
+ Enter an email address to create your account and continue.
+
+
+
+ {{ isSubmitting ? t('common.processing') : 'Create account' }}
+
+
+ I already have an account
+
-
- Continue with email
-
+
+
+ {{ accountActionError }}
+
+
-
+
- Sign in to bind {{ providerName }} to the existing account for
- {{ pendingEmail }} .
+ Log in to an existing account to bind this {{ providerName }} sign-in.
-
- Sign in to bind
-
+
+
+
+
+ {{ isSubmitting ? t('common.processing') : 'Log in and bind' }}
+
+
+ Use a different email
+
+
+
+
+ {{ accountActionError }}
+
+
-
+
- Review the {{ providerName }} profile details before continuing.
+ Enter the 6-digit verification code for
+ {{ totpUserEmailMasked || 'your account' }}
+ to finish binding this {{ providerName }} sign-in.
-
- {{ isSubmitting ? t('common.processing') : 'Continue' }}
-
+
+
+
+ {{ isSubmitting ? t('common.processing') : 'Verify and continue' }}
+
+
+
+
+ {{ totpError }}
+
+
@@ -177,11 +263,12 @@
diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
new file mode 100644
index 00000000..cfbd9f1c
--- /dev/null
+++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
@@ -0,0 +1,80 @@
+import { flushPromises, mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatPaymentCallbackView from '@/views/auth/WechatPaymentCallbackView.vue'
+
+const { replaceMock, routeState, locationState } = vi.hoisted(() => ({
+ replaceMock: vi.fn(),
+ routeState: {
+ query: {} as Record<string, string>,
+ },
+ locationState: {
+ current: {
+ href: 'http://localhost/auth/wechat/payment/callback',
+ hash: '',
+ search: '',
+ pathname: '/auth/wechat/payment/callback',
+ origin: 'http://localhost',
+ } as Location & { origin: string },
+ },
+}))
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+ useRouter: () => ({
+ replace: replaceMock,
+ }),
+}))
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string) => key,
+ locale: { value: 'zh-CN' },
+ }),
+}))
+
+describe('WechatPaymentCallbackView', () => {
+ beforeEach(() => {
+ replaceMock.mockReset()
+ routeState.query = {}
+ locationState.current = {
+ href: 'http://localhost/auth/wechat/payment/callback',
+ hash: '',
+ search: '',
+ pathname: '/auth/wechat/payment/callback',
+ origin: 'http://localhost',
+ } as Location & { origin: string }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ })
+
+ it('redirects back to purchase with openid and payment context from hash fragment', async () => {
+ locationState.current.hash = '#openid=openid-123&payment_type=wxpay&amount=12.5&order_type=balance&redirect=%2Fpurchase%3Ffrom%3Dwechat'
+
+ mount(WechatPaymentCallbackView)
+ await flushPromises()
+
+ expect(replaceMock).toHaveBeenCalledWith({
+ path: '/purchase',
+ query: {
+ from: 'wechat',
+ wechat_resume: '1',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'balance',
+ },
+ })
+ })
+
+ it('shows an error when the callback payload is missing openid', async () => {
+ locationState.current.hash = '#payment_type=wxpay'
+
+ const wrapper = mount(WechatPaymentCallbackView)
+ await flushPromises()
+
+ expect(replaceMock).not.toHaveBeenCalled()
+ expect(wrapper.text()).toContain('微信支付回调缺少 openid。')
+ })
+})
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index f973ad5b..bfb9dae2 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -309,6 +309,20 @@ const previewImage = ref('')
const paymentPhase = ref<'select' | 'paying'>('select')
+interface CreateOrderOptions {
+ openid?: string
+ paymentType?: string
+ isResume?: boolean
+}
+
+interface WeixinJSBridgeLike {
+ invoke(
+ action: string,
+ payload: Record<string, unknown>,
+ callback: (result: Record<string, unknown>) => void,
+ ): void
+}
+
function emptyPaymentState(): PaymentRecoverySnapshot {
return {
orderId: 0,
@@ -326,6 +340,48 @@ function emptyPaymentState(): PaymentRecoverySnapshot {
}
}
+function readRouteQueryValue(value: unknown): string {
+ if (Array.isArray(value)) {
+ return typeof value[0] === 'string' ? value[0] : ''
+ }
+ return typeof value === 'string' ? value : ''
+}
+
+function getWeixinJSBridge(): WeixinJSBridgeLike | undefined {
+ return (window as Window & { WeixinJSBridge?: WeixinJSBridgeLike }).WeixinJSBridge
+}
+
+function waitForWeixinJSBridge(timeoutMs = 4000): Promise<WeixinJSBridgeLike | null> {
+ const existing = getWeixinJSBridge()
+ if (existing) return Promise.resolve(existing)
+
+ return new Promise((resolve) => {
+ let settled = false
+ const finish = (bridge: WeixinJSBridgeLike | null) => {
+ if (settled) return
+ settled = true
+ document.removeEventListener('WeixinJSBridgeReady', handleReady)
+ document.removeEventListener('onWeixinJSBridgeReady', handleReady)
+ window.clearTimeout(timer)
+ resolve(bridge)
+ }
+ const handleReady = () => finish(getWeixinJSBridge() ?? null)
+ const timer = window.setTimeout(() => finish(getWeixinJSBridge() ?? null), timeoutMs)
+ document.addEventListener('WeixinJSBridgeReady', handleReady, false)
+ document.addEventListener('onWeixinJSBridgeReady', handleReady, false)
+ })
+}
+
+async function invokeWechatJsapiPayment(payload: Record<string, unknown>): Promise<Record<string, unknown>> {
+ const bridge = await waitForWeixinJSBridge()
+ if (!bridge) {
+ throw new Error('WeixinJSBridge is unavailable')
+ }
+ return new Promise((resolve) => {
+ bridge.invoke('getBrandWCPayRequest', payload, (result) => resolve(result || {}))
+ })
+}
+
const paymentState = ref(emptyPaymentState())
function persistRecoverySnapshot(snapshot: PaymentRecoverySnapshot) {
@@ -560,25 +616,32 @@ async function confirmSubscribe() {
await createOrder(selectedPlan.value.price, 'subscription', selectedPlan.value.id)
}
-async function createOrder(orderAmount: number, orderType: OrderType, planId?: number) {
+async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) {
submitting.value = true
errorMessage.value = ''
try {
- const result = await paymentStore.createOrder(buildCreateOrderPayload({
+ const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value
+ const payload = buildCreateOrderPayload({
amount: orderAmount,
- paymentType: selectedMethod.value,
+ paymentType: requestType,
orderType,
planId,
origin: typeof window !== 'undefined' ? window.location.origin : '',
isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent),
- })) as CreateOrderResult & { resume_token?: string }
+ })
+ if (options.openid) {
+ payload.openid = options.openid
+ }
+ payload.is_mobile = isMobileDevice()
+
+ const result = await paymentStore.createOrder(payload) as CreateOrderResult & { resume_token?: string }
const openWindow = (url: string, features = POPUP_WINDOW_FEATURES) => {
const win = window.open(url, 'paymentPopup', features)
if (!win || win.closed) {
window.location.href = url
}
}
- const visibleMethod = normalizeVisibleMethod(selectedMethod.value) || selectedMethod.value
+ const visibleMethod = normalizeVisibleMethod(requestType) || requestType
const stripeMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
const stripeRouteUrl = result.client_secret
? router.resolve({
@@ -599,6 +662,11 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
stripeRouteUrl,
})
+ if (decision.kind === 'wechat_oauth' && decision.oauth?.authorize_url) {
+ window.location.href = decision.oauth.authorize_url
+ return
+ }
+
if (decision.kind === 'unhandled') {
errorMessage.value = t('payment.result.failed')
appStore.showError(errorMessage.value)
@@ -617,6 +685,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
window.location.href = decision.paymentState.payUrl
return
}
+ if (decision.kind === 'wechat_jsapi' && decision.jsapi) {
+ const jsapiResult = await invokeWechatJsapiPayment(decision.jsapi as Record<string, unknown>)
+ const errMsg = String(jsapiResult.err_msg || '').toLowerCase()
+ if (errMsg.includes('cancel')) {
+ appStore.showInfo(t('payment.qr.cancelled'))
+ } else if (errMsg && !errMsg.includes('ok')) {
+ appStore.showError(t('payment.result.failed'))
+ }
+ return
+ }
if (decision.kind === 'redirect_waiting' && decision.paymentState.payUrl) {
if (isMobileDevice()) {
window.location.href = decision.paymentState.payUrl
@@ -640,6 +718,50 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
}
}
+async function resumeWechatPaymentFromQuery() {
+ const openid = readRouteQueryValue(route.query.openid)
+ if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) {
+ return
+ }
+
+ const paymentType = normalizeVisibleMethod(readRouteQueryValue(route.query.payment_type)) || 'wxpay'
+ const orderType = readRouteQueryValue(route.query.order_type) === 'subscription' ? 'subscription' : 'balance'
+ const planId = Number.parseInt(readRouteQueryValue(route.query.plan_id), 10)
+ const rawAmount = Number.parseFloat(readRouteQueryValue(route.query.amount))
+ const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0
+ ? rawAmount
+ : (orderType === 'subscription'
+ ? (checkout.value.plans.find(plan => plan.id === planId)?.price ?? 0)
+ : validAmount.value)
+
+ selectedMethod.value = paymentType
+ if (orderType === 'balance' && orderAmount > 0) {
+ amount.value = orderAmount
+ }
+ if (orderType === 'subscription' && Number.isFinite(planId) && planId > 0) {
+ selectedPlan.value = checkout.value.plans.find(plan => plan.id === planId) ?? null
+ }
+
+ const nextQuery = { ...route.query }
+ delete nextQuery.wechat_resume
+ delete nextQuery.openid
+ delete nextQuery.state
+ delete nextQuery.scope
+ delete nextQuery.payment_type
+ delete nextQuery.amount
+ delete nextQuery.order_type
+ delete nextQuery.plan_id
+ await router.replace({ path: route.path, query: nextQuery })
+
+ if (orderAmount > 0) {
+ await createOrder(orderAmount, orderType, Number.isFinite(planId) && planId > 0 ? planId : undefined, {
+ openid,
+ paymentType,
+ isResume: true,
+ })
+ }
+}
+
onMounted(async () => {
try {
const res = await paymentAPI.getCheckoutInfo()
@@ -672,6 +794,7 @@ onMounted(async () => {
removeRecoverySnapshot()
}
}
+ await resumeWechatPaymentFromQuery()
if (checkout.value.balance_disabled) {
activeTab.value = 'subscription'
}
--
GitLab
From a1425b457d91f59c526a2f56cc4f7c78b8c3c3c6 Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 20 Apr 2026 23:38:59 +0800
Subject: [PATCH 071/261] feat(channel-monitor): redesign user dashboard as
card grid
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Reference check-cx UI: INTELLIGENCE MONITOR hero + 3-column card grid
with 60-point timeline bars.
Backend:
- Add PrimaryPingLatencyMs + Timeline[60] to UserMonitorView
- ListRecentHistoryForMonitors: batch CTE + ROW_NUMBER() window query
- indexLatestByModel / indexAvailabilityByModel helpers
Frontend:
- 7 new components: ProviderIcon, MonitorMetricPair, MonitorAvailabilityRow,
MonitorTimeline, MonitorHero, MonitorCard, MonitorCardGrid
- ChannelStatusView 381→~180 lines (delegated to subcomponents)
- AbortController reload concurrency protection
- HSL 0-120° availability color mapping
- Replace emoji with Icon component (bolt / globe)
- i18n: monitorCommon.* shared namespace, channelStatus.hero.*
Bump VERSION to 0.1.114.24
---
.../handler/channel_monitor_user_handler.go | 60 ++--
.../repository/channel_monitor_repo.go | 135 ++++++++-
.../service/channel_monitor_aggregator.go | 95 ++++++-
.../internal/service/channel_monitor_const.go | 3 +
.../service/channel_monitor_service.go | 3 +
.../internal/service/channel_monitor_types.go | 37 ++-
frontend/src/api/channelMonitor.ts | 9 +
.../user/MonitorPrimaryModelCell.vue | 71 -----
.../user/monitor/MonitorAvailabilityRow.vue | 49 ++++
.../components/user/monitor/MonitorCard.vue | 128 +++++++++
.../user/monitor/MonitorCardGrid.vue | 81 ++++++
.../components/user/monitor/MonitorHero.vue | 133 +++++++++
.../user/monitor/MonitorMetricPair.vue | 45 +++
.../user/monitor/MonitorTimeline.vue | 113 ++++++++
.../components/user/monitor/ProviderIcon.vue | 71 +++++
.../composables/useChannelMonitorFormat.ts | 60 +++-
frontend/src/i18n/locales/en.ts | 33 ++-
frontend/src/i18n/locales/zh.ts | 33 ++-
frontend/src/views/user/ChannelStatusView.vue | 257 ++++++++----------
19 files changed, 1136 insertions(+), 280 deletions(-)
delete mode 100644 frontend/src/components/user/MonitorPrimaryModelCell.vue
create mode 100644 frontend/src/components/user/monitor/MonitorAvailabilityRow.vue
create mode 100644 frontend/src/components/user/monitor/MonitorCard.vue
create mode 100644 frontend/src/components/user/monitor/MonitorCardGrid.vue
create mode 100644 frontend/src/components/user/monitor/MonitorHero.vue
create mode 100644 frontend/src/components/user/monitor/MonitorMetricPair.vue
create mode 100644 frontend/src/components/user/monitor/MonitorTimeline.vue
create mode 100644 frontend/src/components/user/monitor/ProviderIcon.vue
diff --git a/backend/internal/handler/channel_monitor_user_handler.go b/backend/internal/handler/channel_monitor_user_handler.go
index a031b4a2..6a513dc1 100644
--- a/backend/internal/handler/channel_monitor_user_handler.go
+++ b/backend/internal/handler/channel_monitor_user_handler.go
@@ -1,6 +1,8 @@
package handler
import (
+ "time"
+
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -22,15 +24,26 @@ func NewChannelMonitorUserHandler(monitorService *service.ChannelMonitorService)
// --- Response ---
type channelMonitorUserListItem struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- Provider string `json:"provider"`
- GroupName string `json:"group_name"`
- PrimaryModel string `json:"primary_model"`
- PrimaryStatus string `json:"primary_status"`
- PrimaryLatencyMs *int `json:"primary_latency_ms"`
- Availability7d float64 `json:"availability_7d"`
- ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"`
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ GroupName string `json:"group_name"`
+ PrimaryModel string `json:"primary_model"`
+ PrimaryStatus string `json:"primary_status"`
+ PrimaryLatencyMs *int `json:"primary_latency_ms"`
+ PrimaryPingLatencyMs *int `json:"primary_ping_latency_ms"`
+ Availability7d float64 `json:"availability_7d"`
+ ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"`
+ Timeline []channelMonitorUserTimelinePoint `json:"timeline"`
+}
+
+// channelMonitorUserTimelinePoint 主模型最近一次检测的 timeline 点。
+// 仅用于用户视图 list 响应,admin 视图不使用。
+type channelMonitorUserTimelinePoint struct {
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+ PingLatencyMs *int `json:"ping_latency_ms"`
+ CheckedAt string `json:"checked_at"`
}
type channelMonitorUserDetailResponse struct {
@@ -60,16 +73,27 @@ func userMonitorViewToItem(v *service.UserMonitorView) channelMonitorUserListIte
LatencyMs: e.LatencyMs,
})
}
+ timeline := make([]channelMonitorUserTimelinePoint, 0, len(v.Timeline))
+ for _, p := range v.Timeline {
+ timeline = append(timeline, channelMonitorUserTimelinePoint{
+ Status: p.Status,
+ LatencyMs: p.LatencyMs,
+ PingLatencyMs: p.PingLatencyMs,
+ CheckedAt: p.CheckedAt.UTC().Format(time.RFC3339),
+ })
+ }
return channelMonitorUserListItem{
- ID: v.ID,
- Name: v.Name,
- Provider: v.Provider,
- GroupName: v.GroupName,
- PrimaryModel: v.PrimaryModel,
- PrimaryStatus: v.PrimaryStatus,
- PrimaryLatencyMs: v.PrimaryLatencyMs,
- Availability7d: v.Availability7d,
- ExtraModels: extras,
+ ID: v.ID,
+ Name: v.Name,
+ Provider: v.Provider,
+ GroupName: v.GroupName,
+ PrimaryModel: v.PrimaryModel,
+ PrimaryStatus: v.PrimaryStatus,
+ PrimaryLatencyMs: v.PrimaryLatencyMs,
+ PrimaryPingLatencyMs: v.PrimaryPingLatencyMs,
+ Availability7d: v.Availability7d,
+ ExtraModels: extras,
+ Timeline: timeline,
}
}
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
index b943f33c..cf5e1a93 100644
--- a/backend/internal/repository/channel_monitor_repo.go
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -243,7 +243,7 @@ func (r *channelMonitorRepository) ListHistory(ctx context.Context, monitorID in
func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monitorID int64) ([]*service.ChannelMonitorLatest, error) {
const q = `
SELECT DISTINCT ON (model)
- model, status, latency_ms, checked_at
+ model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = $1
ORDER BY model, checked_at DESC
@@ -257,19 +257,27 @@ func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monit
out := make([]*service.ChannelMonitorLatest, 0)
for rows.Next() {
l := &service.ChannelMonitorLatest{}
- var latency sql.NullInt64
- if err := rows.Scan(&l.Model, &l.Status, &latency, &l.CheckedAt); err != nil {
+ var latency, ping sql.NullInt64
+ if err := rows.Scan(&l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
return nil, fmt.Errorf("scan latest row: %w", err)
}
- if latency.Valid {
- v := int(latency.Int64)
- l.LatencyMs = &v
- }
+ assignNullInt(&l.LatencyMs, latency)
+ assignNullInt(&l.PingLatencyMs, ping)
out = append(out, l)
}
return out, rows.Err()
}
+// assignNullInt 把 sql.NullInt64 解包到 *int 指针目标(valid 才分配新 int)。
+// 集中实现避免 latency / ping 两处重复 if latency.Valid { v := int(...) ... } 模板。
+func assignNullInt(dst **int, n sql.NullInt64) {
+ if !n.Valid {
+ return
+ }
+ v := int(n.Int64)
+ *dst = &v
+}
+
// ComputeAvailability 计算指定窗口内每个模型的可用率与平均延迟。
// "可用" = status IN (operational, degraded)。
func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) {
@@ -338,7 +346,7 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context,
}
const q = `
SELECT DISTINCT ON (monitor_id, model)
- monitor_id, model, status, latency_ms, checked_at
+ monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = ANY($1)
ORDER BY monitor_id, model, checked_at DESC
@@ -352,14 +360,12 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context,
for rows.Next() {
var monitorID int64
l := &service.ChannelMonitorLatest{}
- var latency sql.NullInt64
- if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &l.CheckedAt); err != nil {
+ var latency, ping sql.NullInt64
+ if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
return nil, fmt.Errorf("scan latest batch row: %w", err)
}
- if latency.Valid {
- v := int(latency.Int64)
- l.LatencyMs = &v
- }
+ assignNullInt(&l.LatencyMs, latency)
+ assignNullInt(&l.PingLatencyMs, ping)
out[monitorID] = append(out[monitorID], l)
}
if err := rows.Err(); err != nil {
@@ -368,6 +374,107 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context,
return out, nil
}
// ListRecentHistoryForMonitors batch-loads, for each monitor, the most recent
// perMonitorLimit history rows of that monitor's designated model, newest
// first (ordered by checked_at DESC).
//
// primaryModels[monitorID] names the model to filter on for that monitor;
// monitors missing from primaryModels (or with a blank model) are skipped.
// The query builds a (monitor_id, model) whitelist with a CTE that unnests two
// parallel arrays (bigint[] / text[]), then keeps the top N rows per monitor
// via ROW_NUMBER() OVER (PARTITION BY monitor_id).
//
// Returns map[monitorID] -> []*ChannelMonitorHistoryEntry; the message column
// is deliberately not selected to keep the payload small. Empty ids or empty
// primaryModels yield an empty map with no error.
func (r *channelMonitorRepository) ListRecentHistoryForMonitors(
	ctx context.Context,
	ids []int64,
	primaryModels map[int64]string,
	perMonitorLimit int,
) (map[int64][]*service.ChannelMonitorHistoryEntry, error) {
	out := make(map[int64][]*service.ChannelMonitorHistoryEntry, len(ids))
	// Drop monitors without a usable model up front; an empty whitelist means
	// there is nothing to query.
	pairIDs, pairModels := buildMonitorModelPairs(ids, primaryModels)
	if len(pairIDs) == 0 {
		return out, nil
	}
	// Clamp the caller-supplied limit into the repository's allowed range.
	perMonitorLimit = clampTimelineLimit(perMonitorLimit)

	const q = `
WITH targets AS (
    SELECT unnest($1::bigint[]) AS monitor_id,
           unnest($2::text[]) AS model
),
ranked AS (
    SELECT h.monitor_id,
           h.status,
           h.latency_ms,
           h.ping_latency_ms,
           h.checked_at,
           ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn
    FROM channel_monitor_histories h
    JOIN targets t
      ON t.monitor_id = h.monitor_id AND t.model = h.model
)
SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
FROM ranked
WHERE rn <= $3
ORDER BY monitor_id, checked_at DESC
`
	rows, err := r.db.QueryContext(ctx, q, pq.Array(pairIDs), pq.Array(pairModels), perMonitorLimit)
	if err != nil {
		return nil, fmt.Errorf("query recent history batch: %w", err)
	}
	// Best-effort close; the Close error carries no extra information here.
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var monitorID int64
		entry := &service.ChannelMonitorHistoryEntry{}
		var latency, ping sql.NullInt64
		if err := rows.Scan(&monitorID, &entry.Status, &latency, &ping, &entry.CheckedAt); err != nil {
			return nil, fmt.Errorf("scan recent history row: %w", err)
		}
		// NULL latencies stay nil pointers rather than becoming zero values.
		assignNullInt(&entry.LatencyMs, latency)
		assignNullInt(&entry.PingLatencyMs, ping)
		out[monitorID] = append(out[monitorID], entry)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
+
+// buildMonitorModelPairs 基于 ids 过滤出有效的 (monitor_id, model) 对,model 为空时跳过。
+// 保证两个数组长度一致且一一对应,供 unnest 展开。
+func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) {
+ if len(ids) == 0 || len(primaryModels) == 0 {
+ return nil, nil
+ }
+ pairIDs := make([]int64, 0, len(ids))
+ pairModels := make([]string, 0, len(ids))
+ for _, id := range ids {
+ model, ok := primaryModels[id]
+ if !ok || strings.TrimSpace(model) == "" {
+ continue
+ }
+ pairIDs = append(pairIDs, id)
+ pairModels = append(pairModels, model)
+ }
+ return pairIDs, pairModels
+}
+
// timelineLimitMin / timelineLimitMax bound the perMonitorLimit used by the
// batched timeline query. The lower bound guarantees at least the most recent
// entry is returned; the upper bound caps both the response size and the
// ROW_NUMBER window cost of a single query.
const (
	timelineLimitMin = 1
	timelineLimitMax = 200
)

// clampTimelineLimit clamps n into [timelineLimitMin, timelineLimitMax] so an
// invalid or oversized caller-supplied limit cannot produce a degenerate or
// unbounded query.
func clampTimelineLimit(n int) int {
	// Built-in min/max (Go 1.21+) replace the hand-rolled two-branch clamp;
	// the surrounding codebase already targets 1.21+ (log/slog usage).
	return min(max(n, timelineLimitMin), timelineLimitMax)
}
+
// ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。
func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) {
out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids))
diff --git a/backend/internal/service/channel_monitor_aggregator.go b/backend/internal/service/channel_monitor_aggregator.go
index 97015b40..09020f5f 100644
--- a/backend/internal/service/channel_monitor_aggregator.go
+++ b/backend/internal/service/channel_monitor_aggregator.go
@@ -49,7 +49,12 @@ func (s *ChannelMonitorService) BatchMonitorStatusSummary(
}
// ListUserView 用户只读视图:列出所有 enabled 监控的概览。
-// 使用批量聚合接口避免 N+1:1 次查 monitors,1 次查 latest(所有 monitor),1 次查 availability。
+// 使用批量聚合接口避免 N+1:
+//
+// 1 次查 monitors;
+// 1 次批量 latest(含 ping_latency_ms);
+// 1 次批量 7d availability;
+// 1 次批量 timeline(主模型最近 N 条)。
func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonitorView, error) {
monitors, err := s.repo.ListEnabled(ctx)
if err != nil {
@@ -59,6 +64,21 @@ func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonito
return []*UserMonitorView{}, nil
}
+ ids, primaryByID, extrasByID := collectMonitorIndexes(monitors)
+ summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID)
+ latestMap := s.batchLatest(ctx, ids)
+ timelineMap := s.batchTimeline(ctx, ids, primaryByID)
+
+ views := make([]*UserMonitorView, 0, len(monitors))
+ for _, m := range monitors {
+ primaryLatest := pickLatest(latestMap[m.ID], m.PrimaryModel)
+ views = append(views, buildUserViewFromSummary(m, summaries[m.ID], primaryLatest, timelineMap[m.ID]))
+ }
+ return views, nil
+}
+
+// collectMonitorIndexes 把 monitors 列表按 ID 展开为聚合查询所需的三个索引结构。
+func collectMonitorIndexes(monitors []*ChannelMonitor) ([]int64, map[int64]string, map[int64][]string) {
ids := make([]int64, 0, len(monitors))
primaryByID := make(map[int64]string, len(monitors))
extrasByID := make(map[int64][]string, len(monitors))
@@ -67,14 +87,44 @@ func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonito
primaryByID[m.ID] = m.PrimaryModel
extrasByID[m.ID] = m.ExtraModels
}
- summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID)
+ return ids, primaryByID, extrasByID
+}
- views := make([]*UserMonitorView, 0, len(monitors))
- for _, m := range monitors {
- summary := summaries[m.ID]
- views = append(views, buildUserViewFromSummary(m, summary))
+// batchLatest 批量取 latest per model,失败仅日志(与现有 BatchMonitorStatusSummary 一致,不阻断列表渲染)。
+func (s *ChannelMonitorService) batchLatest(ctx context.Context, ids []int64) map[int64][]*ChannelMonitorLatest {
+ latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids)
+ if err != nil {
+ slog.Warn("channel_monitor: user view batch latest failed", "error", err)
+ return map[int64][]*ChannelMonitorLatest{}
}
- return views, nil
+ return latestMap
+}
+
+// batchTimeline 批量取每个 monitor 主模型最近 monitorTimelineMaxPoints 条历史。
+func (s *ChannelMonitorService) batchTimeline(
+ ctx context.Context,
+ ids []int64,
+ primaryByID map[int64]string,
+) map[int64][]*ChannelMonitorHistoryEntry {
+ timelineMap, err := s.repo.ListRecentHistoryForMonitors(ctx, ids, primaryByID, monitorTimelineMaxPoints)
+ if err != nil {
+ slog.Warn("channel_monitor: user view batch timeline failed", "error", err)
+ return map[int64][]*ChannelMonitorHistoryEntry{}
+ }
+ return timelineMap
+}
+
+// pickLatest 从 latest 切片中挑出指定 model 对应项,未命中返回 nil。
+func pickLatest(rows []*ChannelMonitorLatest, model string) *ChannelMonitorLatest {
+ if model == "" {
+ return nil
+ }
+ for _, r := range rows {
+ if r.Model == model {
+ return r
+ }
+ }
+ return nil
}
// GetUserDetail 用户只读视图:单个监控详情(每个模型 7d/15d/30d 可用率与平均延迟)。
@@ -170,9 +220,15 @@ func buildStatusSummary(
return summary
}
-// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary 装填 UserMonitorView(无 IO)。
-func buildUserViewFromSummary(m *ChannelMonitor, summary MonitorStatusSummary) *UserMonitorView {
- return &UserMonitorView{
+// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary + 主模型 latest + timeline 装填 UserMonitorView(无 IO)。
+// primaryLatest 可能为 nil(该监控尚无历史);timelineEntries 可能为空。
+func buildUserViewFromSummary(
+ m *ChannelMonitor,
+ summary MonitorStatusSummary,
+ primaryLatest *ChannelMonitorLatest,
+ timelineEntries []*ChannelMonitorHistoryEntry,
+) *UserMonitorView {
+ view := &UserMonitorView{
ID: m.ID,
Name: m.Name,
Provider: m.Provider,
@@ -182,7 +238,26 @@ func buildUserViewFromSummary(m *ChannelMonitor, summary MonitorStatusSummary) *
PrimaryLatencyMs: summary.PrimaryLatencyMs,
Availability7d: summary.Availability7d,
ExtraModels: summary.ExtraModels,
+ Timeline: buildTimelinePoints(timelineEntries),
+ }
+ if primaryLatest != nil {
+ view.PrimaryPingLatencyMs = primaryLatest.PingLatencyMs
}
+ return view
+}
+
+// buildTimelinePoints 把 history entry 裁剪为 timeline 点(去除 message/ID/Model,减小响应体)。
+func buildTimelinePoints(entries []*ChannelMonitorHistoryEntry) []UserMonitorTimelinePoint {
+ out := make([]UserMonitorTimelinePoint, 0, len(entries))
+ for _, e := range entries {
+ out = append(out, UserMonitorTimelinePoint{
+ Status: e.Status,
+ LatencyMs: e.LatencyMs,
+ PingLatencyMs: e.PingLatencyMs,
+ CheckedAt: e.CheckedAt,
+ })
+ }
+ return out
}
// mergeModelDetails 合并 latest + availability 三个窗口为 ModelDetail 列表。
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
index b4c02bcb..7255e4be 100644
--- a/backend/internal/service/channel_monitor_const.go
+++ b/backend/internal/service/channel_monitor_const.go
@@ -65,6 +65,9 @@ const (
// MonitorHistoryMaxLimit 历史查询最大返回条数(handler 层共享)。
MonitorHistoryMaxLimit = 1000
+ // monitorTimelineMaxPoints 用户视图 timeline 每个监控最多返回的历史点数。
+ monitorTimelineMaxPoints = 60
+
// monitorEndpointResolveTimeout validateEndpoint 解析 hostname 的最长耗时。
monitorEndpointResolveTimeout = 5 * time.Second
diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go
index b179e50c..957ace15 100644
--- a/backend/internal/service/channel_monitor_service.go
+++ b/backend/internal/service/channel_monitor_service.go
@@ -38,6 +38,9 @@ type ChannelMonitorRepository interface {
// 批量聚合(admin/user list 用,避免 N+1)
ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*ChannelMonitorLatest, error)
ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*ChannelMonitorAvailability, error)
+ // ListRecentHistoryForMonitors 批量取多个 monitor 各自主模型(primaryModels[monitorID])最近 perMonitorLimit 条历史。
+ // 返回的 entry 已按 checked_at DESC 排序(最新在前),不含 message 字段。
+ ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error)
}
// ChannelMonitorService 渠道监控管理服务。
diff --git a/backend/internal/service/channel_monitor_types.go b/backend/internal/service/channel_monitor_types.go
index 4b34d8af..739c82fb 100644
--- a/backend/internal/service/channel_monitor_types.go
+++ b/backend/internal/service/channel_monitor_types.go
@@ -72,15 +72,25 @@ type CheckResult struct {
// UserMonitorView is the read-only user-facing monitor overview: the primary
// model's most recent status and latencies, the 7-day availability, the extra
// models' latest statuses, and a short history timeline for the primary model.
type UserMonitorView struct {
	ID                   int64
	Name                 string
	Provider             string
	GroupName            string
	PrimaryModel         string
	PrimaryStatus        string
	PrimaryLatencyMs     *int    // dialog latency of the primary model's latest check; nil when unknown
	PrimaryPingLatencyMs *int    // ping latency of the primary model's latest check; nil when unknown
	Availability7d       float64 // 0-100
	ExtraModels          []ExtraModelStatus
	Timeline             []UserMonitorTimelinePoint // latest N history points of the primary model, checked_at DESC (newest first)
}
+
// UserMonitorTimelinePoint is a single timeline datapoint in the user view
// (the history message is omitted to keep the response body small).
type UserMonitorTimelinePoint struct {
	Status        string    `json:"status"`
	LatencyMs     *int      `json:"latency_ms"`      // nil when the check recorded no dialog latency
	PingLatencyMs *int      `json:"ping_latency_ms"` // nil when the check recorded no ping latency
	CheckedAt     time.Time `json:"checked_at"`
}
// ExtraModelStatus 附加模型最近一次状态。
@@ -134,10 +144,11 @@ type ChannelMonitorHistoryEntry struct {
// ChannelMonitorLatest is the condensed result of a model's most recent check
// (used when aggregating UserMonitorView).
type ChannelMonitorLatest struct {
	Model         string
	Status        string
	LatencyMs     *int // nil when the check recorded no dialog latency
	PingLatencyMs *int // nil when the check recorded no ping latency
	CheckedAt     time.Time
}
// ChannelMonitorAvailability 单个模型在某窗口内的可用率与平均延迟(用于 UserMonitorDetail 聚合)。
diff --git a/frontend/src/api/channelMonitor.ts b/frontend/src/api/channelMonitor.ts
index c5481636..38dd0c99 100644
--- a/frontend/src/api/channelMonitor.ts
+++ b/frontend/src/api/channelMonitor.ts
@@ -14,6 +14,13 @@ export interface UserMonitorExtraModel {
latency_ms: number | null
}
/**
 * One point on a monitor's recent-history timeline.
 * Mirrors the backend's user-view timeline JSON shape.
 */
export interface MonitorTimelinePoint {
  /** Monitor status recorded at check time. */
  status: MonitorStatus
  /** Dialog latency in milliseconds; null when not recorded. */
  latency_ms: number | null
  /** Endpoint ping latency in milliseconds; null when not recorded. */
  ping_latency_ms: number | null
  /** Check timestamp as an RFC3339 string. */
  checked_at: string
}
+
export interface UserMonitorView {
id: number
name: string
@@ -22,8 +29,10 @@ export interface UserMonitorView {
primary_model: string
primary_status: MonitorStatus
primary_latency_ms: number | null
+ primary_ping_latency_ms: number | null
availability_7d: number
extra_models: UserMonitorExtraModel[]
+ timeline: MonitorTimelinePoint[]
}
export interface UserMonitorListResponse {
diff --git a/frontend/src/components/user/MonitorPrimaryModelCell.vue b/frontend/src/components/user/MonitorPrimaryModelCell.vue
deleted file mode 100644
index 32620b2a..00000000
--- a/frontend/src/components/user/MonitorPrimaryModelCell.vue
+++ /dev/null
@@ -1,71 +0,0 @@
-
-
-
{{ row.primary_model }}
-
-
-
- {{ statusLabel(row.primary_status) }}
-
-
-
-
- {{ row.primary_model }}
-
- {{ statusLabel(row.primary_status) }}
-
-
-
- {{ t('monitorCommon.extraModelsEmpty') }}
-
-
-
- {{ t('monitorCommon.extraModelsHeader') }}
-
-
-
-
- {{ t('channelStatus.detailColumns.model') }}
- {{ t('channelStatus.detailColumns.latestStatus') }}
- {{ t('channelStatus.detailColumns.latestLatency') }}
-
-
-
-
- {{ m.model }}
-
-
- {{ statusLabel(m.status) }}
-
-
- {{ formatLatency(m.latency_ms) }}
-
-
-
-
-
-
-
-
-
-
diff --git a/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue b/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue
new file mode 100644
index 00000000..34420c9d
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue
@@ -0,0 +1,49 @@
+
+
+
+ {{ windowLabel }}
+
+
+
+ {{ displayValue }}
+
+ %
+
+
+
+ {{ samplesLabel }}
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorCard.vue b/frontend/src/components/user/monitor/MonitorCard.vue
new file mode 100644
index 00000000..33742c6d
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorCard.vue
@@ -0,0 +1,128 @@
+
+
+
+
+
+
+
+
+
+ {{ item.name }}
+
+
+
+ {{ providerLabel(item.provider) }}
+
+
+ {{ item.primary_model }}
+
+
+ {{ item.group_name }}
+
+
+
+
+ {{ statusLabel(item.primary_status) }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorCardGrid.vue b/frontend/src/components/user/monitor/MonitorCardGrid.vue
new file mode 100644
index 00000000..c7d24c01
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorCardGrid.vue
@@ -0,0 +1,81 @@
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorHero.vue b/frontend/src/components/user/monitor/MonitorHero.vue
new file mode 100644
index 00000000..be5a96b8
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorHero.vue
@@ -0,0 +1,133 @@
+
+
+
+ {{ t('channelStatus.hero.breadcrumb') }}
+
+
+
+
+ {{ t('channelStatus.hero.title') }}
+
+
+ {{ t('channelStatus.hero.subtitleZh') }}
+
+
+ {{ t('channelStatus.hero.subtitleEn') }}
+
+
+
+
+
+
+ {{ opt.label }}
+
+
+
+
+
+
+ {{ overallLabel }}
+
+
+
+
+
+
+
+ {{ updatedLabel }} · {{ t('monitorCommon.pollEvery', { n: intervalSeconds }) }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorMetricPair.vue b/frontend/src/components/user/monitor/MonitorMetricPair.vue
new file mode 100644
index 00000000..0f3fd3dc
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorMetricPair.vue
@@ -0,0 +1,45 @@
+
+
+
+
+
+ {{ primaryLabel }}
+
+
+ {{ primaryValue }}{{ primaryUnit }}
+
+
+
+
+
+ {{ secondaryLabel }}
+
+
+ {{ secondaryValue }}{{ secondaryUnit }}
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorTimeline.vue b/frontend/src/components/user/monitor/MonitorTimeline.vue
new file mode 100644
index 00000000..b4d0c151
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorTimeline.vue
@@ -0,0 +1,113 @@
+
+
+
+ {{ t('monitorCommon.history60pts', { n: length }) }}
+ {{ t('monitorCommon.nextUpdateIn', { n: countdownSeconds }) }}
+
+
+
+ {{ t('monitorCommon.maintenancePaused') }}
+
+
+
+
+ {{ t('monitorCommon.past') }}
+ {{ t('monitorCommon.now') }}
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/ProviderIcon.vue b/frontend/src/components/user/monitor/ProviderIcon.vue
new file mode 100644
index 00000000..20456a2c
--- /dev/null
+++ b/frontend/src/components/user/monitor/ProviderIcon.vue
@@ -0,0 +1,71 @@
+
+
+
+
+
+ {{ fallbackText }}
+
+
+
+
diff --git a/frontend/src/composables/useChannelMonitorFormat.ts b/frontend/src/composables/useChannelMonitorFormat.ts
index fbb310fa..7ffdaa42 100644
--- a/frontend/src/composables/useChannelMonitorFormat.ts
+++ b/frontend/src/composables/useChannelMonitorFormat.ts
@@ -4,6 +4,7 @@
* Centralises:
* - status / provider label + badge class lookups
* - latency / availability / percent number formatting
+ * - dashboard-style helpers (HSL for availability, provider gradient, relative time)
*
* i18n keys live under `monitorCommon.*` so admin and user views share the
* same translation source.
@@ -23,6 +24,11 @@ import {
const NEUTRAL_BADGE = 'bg-gray-100 text-gray-800 dark:bg-dark-700 dark:text-gray-300'
+/** Availability HSL hue multiplier: 0%=red(0) / 50%=yellow(60) / 100%=green(120). */
+const HSL_HUE_PER_PERCENT = 1.2
+const HSL_SATURATION = 72
+const HSL_LIGHTNESS = 42
+
export interface AvailabilityRow {
primary_status: MonitorStatus | ''
availability_7d: number | null | undefined
@@ -39,11 +45,11 @@ export function useChannelMonitorFormat() {
function statusBadgeClass(s: MonitorStatus | ''): string {
switch (s) {
case STATUS_OPERATIONAL:
- return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300'
+ return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-500/15 dark:text-emerald-300'
case STATUS_DEGRADED:
- return 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900/30 dark:text-yellow-300'
+ return 'bg-amber-100 text-amber-700 dark:bg-amber-500/15 dark:text-amber-300'
case STATUS_FAILED:
- return 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-300'
+ return 'bg-red-100 text-red-700 dark:bg-red-500/15 dark:text-red-300'
case STATUS_ERROR:
default:
return NEUTRAL_BADGE
@@ -60,11 +66,11 @@ export function useChannelMonitorFormat() {
function providerBadgeClass(p: Provider | string): string {
switch (p) {
case PROVIDER_OPENAI:
- return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300'
+ return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-500/15 dark:text-emerald-300'
case PROVIDER_ANTHROPIC:
- return 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-300'
+ return 'bg-orange-100 text-orange-700 dark:bg-orange-500/15 dark:text-orange-300'
case PROVIDER_GEMINI:
- return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-300'
+ return 'bg-sky-100 text-sky-700 dark:bg-sky-500/15 dark:text-sky-300'
default:
return NEUTRAL_BADGE
}
@@ -85,6 +91,20 @@ export function useChannelMonitorFormat() {
return formatPercent(row.availability_7d)
}
+ function formatRelativeTime(iso: string | null | undefined): string {
+ if (!iso) return t('monitorCommon.latencyEmpty')
+ const ts = Date.parse(iso)
+ if (Number.isNaN(ts)) return t('monitorCommon.latencyEmpty')
+ const diffSec = Math.max(0, Math.floor((Date.now() - ts) / 1000))
+ if (diffSec < 60) return t('monitorCommon.relativeSecondsAgo', { n: diffSec })
+ const diffMin = Math.floor(diffSec / 60)
+ if (diffMin < 60) return t('monitorCommon.relativeMinutesAgo', { n: diffMin })
+ const diffHour = Math.floor(diffMin / 60)
+ if (diffHour < 24) return t('monitorCommon.relativeHoursAgo', { n: diffHour })
+ const diffDay = Math.floor(diffHour / 24)
+ return t('monitorCommon.relativeDaysAgo', { n: diffDay })
+ }
+
return {
statusLabel,
statusBadgeClass,
@@ -93,5 +113,33 @@ export function useChannelMonitorFormat() {
formatLatency,
formatPercent,
formatAvailability,
+ formatRelativeTime,
+ }
+}
+
+/**
+ * Map availability percent to an HSL colour (red -> yellow -> green).
+ * Returns undefined for null/NaN so callers can fall back to a neutral colour.
+ */
+export function hslForPct(pct: number | null | undefined): string | undefined {
+ if (pct === null || pct === undefined || Number.isNaN(pct)) return undefined
+ const clamped = Math.max(0, Math.min(100, pct))
+ const hue = clamped * HSL_HUE_PER_PERCENT
+ return `hsl(${hue} ${HSL_SATURATION}% ${HSL_LIGHTNESS}%)`
+}
+
+/**
+ * Tailwind gradient class for the provider icon tile background.
+ */
+export function providerGradient(provider: string): string {
+ switch (provider) {
+ case PROVIDER_OPENAI:
+ return 'bg-gradient-to-br from-emerald-50 to-emerald-100 dark:from-emerald-500/10 dark:to-emerald-500/20'
+ case PROVIDER_ANTHROPIC:
+ return 'bg-gradient-to-br from-orange-50 to-amber-100 dark:from-orange-500/10 dark:to-amber-500/20'
+ case PROVIDER_GEMINI:
+ return 'bg-gradient-to-br from-sky-50 to-indigo-100 dark:from-sky-500/10 dark:to-indigo-500/20'
+ default:
+ return 'bg-gradient-to-br from-gray-100 to-gray-200 dark:from-dark-700 dark:to-dark-600'
}
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 32fbce19..b95c8b44 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -867,7 +867,22 @@ export default {
},
extraModelsHeader: 'Extra Models',
extraModelsEmpty: 'No extra models',
- latencyEmpty: '-'
+ latencyEmpty: '-',
+ availabilityPrefix: 'Availability',
+ dialogLatency: 'Dialog Latency',
+ endpointPing: 'Endpoint PING',
+ history60pts: 'HISTORY ({n} PTS)',
+ nextUpdateIn: 'NEXT UPDATE IN {n}s',
+ past: 'PAST',
+ now: 'NOW',
+ maintenancePaused: 'Maintenance · timeline paused',
+ extraModelsCount: '+ {n} models',
+ pollEvery: '{n}s polling',
+ updatedAt: 'Updated {time}',
+ relativeSecondsAgo: '{n}s ago',
+ relativeMinutesAgo: '{n}m ago',
+ relativeHoursAgo: '{n}h ago',
+ relativeDaysAgo: '{n}d ago'
},
// Channel Status (user-facing read-only view)
@@ -880,6 +895,22 @@ export default {
detailLoadError: 'Failed to load channel detail',
detailTitle: 'Channel Detail',
closeDetail: 'Close',
+ hero: {
+ breadcrumb: 'CHANNEL · STATUS',
+ title: 'INTELLIGENCE MONITOR',
+ subtitleZh: 'Real-time tracking of availability, latency and status for leading AI endpoints.',
+ subtitleEn: 'Advanced performance metrics for next-gen intelligence.'
+ },
+ windowTab: {
+ '7d': '7 days',
+ '15d': '15 days',
+ '30d': '30 days'
+ },
+ overall: {
+ operational: 'OPERATIONAL',
+ degraded: 'DEGRADED',
+ unavailable: 'UNAVAILABLE'
+ },
columns: {
name: 'Name',
provider: 'Provider',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index dd3af363..54bc03c5 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -871,7 +871,22 @@ export default {
},
extraModelsHeader: '附加模型',
extraModelsEmpty: '无附加模型',
- latencyEmpty: '-'
+ latencyEmpty: '-',
+ availabilityPrefix: '可用性',
+ dialogLatency: '对话延迟',
+ endpointPing: '端点 PING',
+ history60pts: '近 {n} 次记录',
+ nextUpdateIn: '{n}s 后刷新',
+ past: 'PAST',
+ now: 'NOW',
+ maintenancePaused: '维护中 · 已暂停时间线采集',
+ extraModelsCount: '+ {n} 模型',
+ pollEvery: '{n}s 轮询',
+ updatedAt: '更新于 {time}',
+ relativeSecondsAgo: '{n} 秒前',
+ relativeMinutesAgo: '{n} 分钟前',
+ relativeHoursAgo: '{n} 小时前',
+ relativeDaysAgo: '{n} 天前'
},
// Channel Status (user-facing read-only view)
@@ -884,6 +899,22 @@ export default {
detailLoadError: '加载渠道详情失败',
detailTitle: '渠道详情',
closeDetail: '关闭',
+ hero: {
+ breadcrumb: '渠道 · 状态',
+ title: 'INTELLIGENCE MONITOR',
+ subtitleZh: '实时追踪各大 AI 模型对话接口的可用性、延迟与官方服务状态。',
+ subtitleEn: 'Advanced performance metrics for next-gen intelligence.'
+ },
+ windowTab: {
+ '7d': '7 天',
+ '15d': '15 天',
+ '30d': '30 天'
+ },
+ overall: {
+ operational: 'OPERATIONAL',
+ degraded: 'DEGRADED',
+ unavailable: 'UNAVAILABLE'
+ },
columns: {
name: '名称',
provider: '供应商',
diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue
index 9f5fe8d1..af427cca 100644
--- a/frontend/src/views/user/ChannelStatusView.vue
+++ b/frontend/src/views/user/ChannelStatusView.vue
@@ -1,93 +1,23 @@
-
-
-
-
-
-
-
-
-
- {{ row.name }}
-
-
-
-
-
- {{ providerLabel(row.provider) }}
-
-
-
-
- {{ value || '-' }}
-
-
-
-
-
-
-
-
- {{ formatAvailability(row) }}
-
-
-
-
-
- {{ formatLatency(row.primary_latency_ms) }}
-
-
-
-
-
-
-
-
-
+
+
+
--
GitLab
From 0fa47f18ed556a49ec582669e85b6863cdcdb2fd Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:02:51 +0800
Subject: [PATCH 072/261] feat: complete pending oauth account creation UI
---
.../auth/PendingOAuthCreateAccountForm.vue | 212 ++++++++++++++++++
.../PendingOAuthCreateAccountForm.spec.ts | 80 +++++++
.../src/views/auth/LinuxDoCallbackView.vue | 55 ++---
frontend/src/views/auth/OidcCallbackView.vue | 55 ++---
.../src/views/auth/WechatCallbackView.vue | 55 ++---
.../__tests__/LinuxDoCallbackView.spec.ts | 43 +++-
.../auth/__tests__/OidcCallbackView.spec.ts | 43 +++-
.../auth/__tests__/WechatCallbackView.spec.ts | 42 +++-
8 files changed, 469 insertions(+), 116 deletions(-)
create mode 100644 frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
create mode 100644 frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
new file mode 100644
index 00000000..39588a86
--- /dev/null
+++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
@@ -0,0 +1,212 @@
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
new file mode 100644
index 00000000..63aeebc6
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
@@ -0,0 +1,80 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'

import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue'

// Spy standing in for the auth API's sendVerifyCode; reset before each test.
const sendVerifyCode = vi.fn()

// Stub vue-i18n so `t` echoes the key, keeping assertions locale-independent.
vi.mock('vue-i18n', async () => {
  const actual = await vi.importActual('vue-i18n')
  return {
    ...actual,
    useI18n: () => ({
      t: (key: string) => key
    })
  }
})

// Route the component's sendVerifyCode import through the spy above while
// leaving the rest of the auth API module intact.
vi.mock('@/api/auth', async () => {
  const actual = await vi.importActual('@/api/auth')
  return {
    ...actual,
    sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
  }
})

describe('PendingOAuthCreateAccountForm', () => {
  beforeEach(() => {
    sendVerifyCode.mockReset()
  })

  // The submit payload must carry trimmed email / verify-code values so the
  // backend never sees accidental surrounding whitespace.
  it('emits trimmed email, password, and verify code on submit', async () => {
    const wrapper = mount(PendingOAuthCreateAccountForm, {
      props: {
        providerName: 'LinuxDo',
        testIdPrefix: 'linuxdo',
        initialEmail: 'prefill@example.com',
        isSubmitting: false
      }
    })

    await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
    await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
    await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue(' 246810 ')
    await wrapper.get('form').trigger('submit.prevent')

    expect(wrapper.emitted('submit')).toEqual([
      [
        {
          email: 'user@example.com',
          password: 'secret-123',
          verifyCode: '246810'
        }
      ]
    ])
  })

  // Sending a verification code must also use the trimmed email.
  it('sends a verify code for the trimmed email value', async () => {
    sendVerifyCode.mockResolvedValue({
      message: 'sent',
      countdown: 60
    })

    const wrapper = mount(PendingOAuthCreateAccountForm, {
      props: {
        providerName: 'LinuxDo',
        testIdPrefix: 'linuxdo',
        initialEmail: '',
        isSubmitting: false
      }
    })

    await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
    await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
    await flushPromises()

    expect(sendVerifyCode).toHaveBeenCalledWith({
      email: 'user@example.com'
    })
  })
})
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index ce7f92ef..128410bc 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -113,37 +113,14 @@
Enter an email address to create your account and continue.
-
-
-
- {{ isSubmitting ? t('common.processing') : 'Create account' }}
-
-
- I already have an account
-
-
-
-
- {{ accountActionError }}
-
-
+
@@ -258,6 +235,9 @@ import { computed, onMounted, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
+import PendingOAuthCreateAccountForm, {
+ type PendingOAuthCreateAccountPayload
+} from '@/components/auth/PendingOAuthCreateAccountForm.vue'
import Icon from '@/components/icons/Icon.vue'
import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
@@ -432,9 +412,9 @@ function applyTotpChallenge(completion: LinuxDoPendingActionResponse): boolean {
return true
}
-function switchToBindLoginMode() {
+function switchToBindLoginMode(nextEmail?: string) {
pendingAccountAction.value = 'bind_login'
- bindLoginEmail.value = bindLoginEmail.value.trim() || pendingAccountEmail.value.trim()
+ bindLoginEmail.value = bindLoginEmail.value.trim() || nextEmail?.trim() || pendingAccountEmail.value.trim()
bindLoginPassword.value = ''
accountActionError.value = ''
canReturnToCreateAccount.value = true
@@ -533,15 +513,16 @@ async function handleContinueLogin() {
}
}
-async function handleCreateAccount() {
+async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
accountActionError.value = ''
- const email = pendingAccountEmail.value.trim()
- if (!email) return
+ if (!payload.email || !payload.password) return
isSubmitting.value = true
try {
const { data } = await apiClient.post('/auth/oauth/pending/create-account', {
- email,
+ email: payload.email,
+ password: payload.password,
+ verify_code: payload.verifyCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index de3f2e40..344c3537 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -122,37 +122,14 @@
Enter an email address to create your account and continue.
-
-
-
- {{ isSubmitting ? t('common.processing') : 'Create account' }}
-
-
- I already have an account
-
-
-
-
- {{ accountActionError }}
-
-
+
@@ -267,6 +244,9 @@ import { computed, onMounted, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
+import PendingOAuthCreateAccountForm, {
+ type PendingOAuthCreateAccountPayload
+} from '@/components/auth/PendingOAuthCreateAccountForm.vue'
import Icon from '@/components/icons/Icon.vue'
import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
@@ -476,9 +456,9 @@ function applyTotpChallenge(completion: PendingOidcCompletion): boolean {
return true
}
-function switchToBindLoginMode() {
+function switchToBindLoginMode(nextEmail?: string) {
pendingAccountAction.value = 'bind_login'
- bindLoginEmail.value = bindLoginEmail.value.trim() || pendingAccountEmail.value.trim()
+ bindLoginEmail.value = bindLoginEmail.value.trim() || nextEmail?.trim() || pendingAccountEmail.value.trim()
bindLoginPassword.value = ''
accountActionError.value = ''
canReturnToCreateAccount.value = true
@@ -577,15 +557,16 @@ async function handleContinueLogin() {
}
}
-async function handleCreateAccount() {
+async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
accountActionError.value = ''
- const email = pendingAccountEmail.value.trim()
- if (!email) return
+ if (!payload.email || !payload.password) return
isSubmitting.value = true
try {
const { data } = await apiClient.post('/auth/oauth/pending/create-account', {
- email,
+ email: payload.email,
+ password: payload.password,
+ verify_code: payload.verifyCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index e4dd6301..10b83b1c 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -160,37 +160,14 @@
Enter an email address to create your account and continue.
-
-
-
- {{ isSubmitting ? t('common.processing') : 'Create account' }}
-
-
- I already have an account
-
-
-
-
- {{ accountActionError }}
-
-
+
@@ -305,6 +282,9 @@ import { computed, onMounted, ref } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { AuthLayout } from '@/components/layout'
+import PendingOAuthCreateAccountForm, {
+ type PendingOAuthCreateAccountPayload
+} from '@/components/auth/PendingOAuthCreateAccountForm.vue'
import Icon from '@/components/icons/Icon.vue'
import { apiClient } from '@/api/client'
import { useAuthStore, useAppStore } from '@/stores'
@@ -575,9 +555,9 @@ function applyTotpChallenge(completion: PendingWeChatCompletion): boolean {
return true
}
-function switchToBindLoginMode() {
+function switchToBindLoginMode(nextEmail?: string) {
pendingAccountAction.value = 'bind_login'
- bindLoginEmail.value = bindLoginEmail.value.trim() || pendingAccountEmail.value.trim()
+ bindLoginEmail.value = bindLoginEmail.value.trim() || nextEmail?.trim() || pendingAccountEmail.value.trim()
bindLoginPassword.value = ''
accountActionError.value = ''
canReturnToCreateAccount.value = true
@@ -676,15 +656,16 @@ async function handleContinueLogin() {
}
}
-async function handleCreateAccount() {
+async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
accountActionError.value = ''
- const email = pendingAccountEmail.value.trim()
- if (!email) return
+ if (!payload.email || !payload.password) return
isSubmitting.value = true
try {
const { data } = await apiClient.post('/auth/oauth/pending/create-account', {
- email,
+ email: payload.email,
+ password: payload.password,
+ verify_code: payload.verifyCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index 8b2482fa..b9930b70 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -11,6 +11,7 @@ const exchangePendingOAuthCompletion = vi.fn()
const completeLinuxDoOAuthRegistration = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
+const sendVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -53,7 +54,8 @@ vi.mock('@/api/auth', async () => {
...actual,
exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args),
- login2FA: (...args: any[]) => login2FA(...args)
+ login2FA: (...args: any[]) => login2FA(...args),
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
}
})
@@ -67,6 +69,7 @@ describe('LinuxDoCallbackView', () => {
completeLinuxDoOAuthRegistration.mockReset()
login2FA.mockReset()
apiClientPost.mockReset()
+ sendVerifyCode.mockReset()
})
it('does not send adoption decisions during the initial exchange', async () => {
@@ -251,7 +254,7 @@ describe('LinuxDoCallbackView', () => {
})
})
- it('collects email for pending oauth account creation and submits adoption decisions', async () => {
+ it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -286,11 +289,15 @@ describe('LinuxDoCallbackView', () => {
expect(checkboxes).toHaveLength(2)
await checkboxes[1].setValue(false)
await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
await flushPromises()
expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
email: 'new@example.com',
+ password: 'secret-123',
+ verify_code: '246810',
adopt_display_name: true,
adopt_avatar: false
})
@@ -298,6 +305,38 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('sends a verify code for pending oauth account creation', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ sendVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCode).toHaveBeenCalledWith({
+ email: 'new@example.com'
+ })
+ })
+
it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'bind_login_required',
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index c7460d8f..cb28e283 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -12,6 +12,7 @@ const completeOIDCOAuthRegistration = vi.fn()
const getPublicSettings = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
+const sendVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -60,7 +61,8 @@ vi.mock('@/api/auth', async () => {
exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args),
completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args),
- login2FA: (...args: any[]) => login2FA(...args)
+ login2FA: (...args: any[]) => login2FA(...args),
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
}
})
@@ -75,6 +77,7 @@ describe('OidcCallbackView', () => {
getPublicSettings.mockReset()
login2FA.mockReset()
apiClientPost.mockReset()
+ sendVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({
oidc_oauth_provider_name: 'ExampleID'
})
@@ -234,7 +237,7 @@ describe('OidcCallbackView', () => {
})
})
- it('collects email for pending oauth account creation and submits adoption decisions', async () => {
+ it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -269,11 +272,15 @@ describe('OidcCallbackView', () => {
expect(checkboxes).toHaveLength(2)
await checkboxes[1].setValue(false)
await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
await flushPromises()
expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
email: 'new@example.com',
+ password: 'secret-123',
+ verify_code: '246810',
adopt_display_name: true,
adopt_avatar: false
})
@@ -281,6 +288,38 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('sends a verify code for pending oauth account creation', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ sendVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCode).toHaveBeenCalledWith({
+ email: 'new@example.com'
+ })
+ })
+
it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'adopt_existing_user_by_email',
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index c49d0243..aa673238 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -7,6 +7,7 @@ const {
completeWeChatOAuthRegistrationMock,
login2FAMock,
apiClientPostMock,
+ sendVerifyCodeMock,
prepareOAuthBindAccessTokenCookieMock,
getAuthTokenMock,
replaceMock,
@@ -20,6 +21,7 @@ const {
completeWeChatOAuthRegistrationMock: vi.fn(),
login2FAMock: vi.fn(),
apiClientPostMock: vi.fn(),
+ sendVerifyCodeMock: vi.fn(),
prepareOAuthBindAccessTokenCookieMock: vi.fn(),
getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(),
@@ -118,6 +120,7 @@ vi.mock('@/api/auth', async () => {
exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args),
completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
login2FA: (...args: any[]) => login2FAMock(...args),
+ sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args),
getAuthToken: (...args: any[]) => getAuthTokenMock(...args),
}
@@ -129,6 +132,7 @@ describe('WechatCallbackView', () => {
completeWeChatOAuthRegistrationMock.mockReset()
login2FAMock.mockReset()
apiClientPostMock.mockReset()
+ sendVerifyCodeMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
showSuccessMock.mockReset()
@@ -374,7 +378,7 @@ describe('WechatCallbackView', () => {
expect(locationState.current.href).toContain('mode=open')
})
- it('collects email for pending oauth account creation and submits adoption decisions', async () => {
+ it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -409,11 +413,15 @@ describe('WechatCallbackView', () => {
expect(checkboxes).toHaveLength(2)
await checkboxes[1].setValue(false)
await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810')
await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
await flushPromises()
expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
email: 'new@example.com',
+ password: 'secret-123',
+ verify_code: '246810',
adopt_display_name: true,
adopt_avatar: false,
})
@@ -421,6 +429,38 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome')
})
+ it('sends a verify code for pending oauth account creation', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome',
+ })
+ sendVerifyCodeMock.mockResolvedValue({
+ message: 'sent',
+ countdown: 60,
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
+ await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'new@example.com',
+ })
+ })
+
it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
step: 'bind_login_required',
--
GitLab
From 4ebdfcd13a6966085fdbc5145d1ac93222d11b35 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:03:27 +0800
Subject: [PATCH 073/261] test(admin): constrain payment visible method sources
---
.../settings.paymentVisibleMethods.spec.ts | 63 +++
frontend/src/api/admin/settings.ts | 68 +++
frontend/src/views/admin/SettingsView.vue | 61 ++-
.../admin/__tests__/SettingsView.spec.ts | 452 ++++++++++++++++++
4 files changed, 631 insertions(+), 13 deletions(-)
create mode 100644 frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
create mode 100644 frontend/src/views/admin/__tests__/SettingsView.spec.ts
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
new file mode 100644
index 00000000..3b1a373f
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -0,0 +1,63 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
+} from '@/api/admin/settings'
+
+describe('admin settings payment visible method helpers', () => {
+ it('normalizes aliases into canonical source keys per visible method', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay')
+
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'wechat')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay')
+ })
+
+ it('rejects unknown or cross-method source values', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('')
+ })
+
+ it('exposes method-scoped source options instead of arbitrary strings', () => {
+ expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
+ {
+ value: '',
+ labelZh: '自动路由',
+ labelEn: 'Automatic routing',
+ },
+ {
+ value: 'official_alipay',
+ labelZh: '支付宝官方',
+ labelEn: 'Official Alipay',
+ },
+ {
+ value: 'easypay_alipay',
+ labelZh: '易支付支付宝',
+ labelEn: 'EasyPay Alipay',
+ },
+ ])
+
+ expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
+ {
+ value: '',
+ labelZh: '自动路由',
+ labelEn: 'Automatic routing',
+ },
+ {
+ value: 'official_wxpay',
+ labelZh: '微信官方',
+ labelEn: 'Official WeChat Pay',
+ },
+ {
+ value: 'easypay_wxpay',
+ labelZh: '易支付微信',
+ labelEn: 'EasyPay WeChat Pay',
+ },
+ ])
+ })
+})
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 8e182c1c..505fcdca 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -22,10 +22,60 @@ export interface AuthSourceDefaultsValue {
}
export type AuthSourceDefaultsState = Record
+export type PaymentVisibleMethod = 'alipay' | 'wxpay'
+export type PaymentVisibleMethodSource =
+ | ''
+ | 'official_alipay'
+ | 'easypay_alipay'
+ | 'official_wxpay'
+ | 'easypay_wxpay'
+
+export interface PaymentVisibleMethodSourceOption {
+ value: PaymentVisibleMethodSource
+ labelZh: string
+ labelEn: string
+}
const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat']
const AUTH_SOURCE_DEFAULT_BALANCE = 0
const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5
+const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
+ PaymentVisibleMethod,
+ PaymentVisibleMethodSourceOption[]
+> = {
+ alipay: [
+ { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' },
+ { value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' },
+ ],
+ wxpay: [
+ { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' },
+ { value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' },
+ ],
+}
+const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record<
+ PaymentVisibleMethod,
+ Record
+> = {
+ alipay: {
+ official_alipay: 'official_alipay',
+ alipay: 'official_alipay',
+ alipay_direct: 'official_alipay',
+ official: 'official_alipay',
+ easypay_alipay: 'easypay_alipay',
+ easypay: 'easypay_alipay',
+ },
+ wxpay: {
+ official_wxpay: 'official_wxpay',
+ wxpay: 'official_wxpay',
+ wxpay_direct: 'official_wxpay',
+ wechat: 'official_wxpay',
+ official: 'official_wxpay',
+ easypay_wxpay: 'easypay_wxpay',
+ easypay: 'easypay_wxpay',
+ },
+}
export function normalizeDefaultSubscriptionSettings(
subscriptions: DefaultSubscriptionSetting[] | null | undefined
@@ -86,6 +136,24 @@ export function appendAuthSourceDefaultsToUpdateRequest(
return payload
}
+export function getPaymentVisibleMethodSourceOptions(
+ method: PaymentVisibleMethod
+): PaymentVisibleMethodSourceOption[] {
+ return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]
+}
+
+export function normalizePaymentVisibleMethodSource(
+ method: PaymentVisibleMethod,
+ source: unknown
+): PaymentVisibleMethodSource {
+ if (typeof source !== 'string') return ''
+
+ const normalized = source.trim().toLowerCase()
+ if (!normalized) return ''
+
+ return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ''
+}
+
/**
* System settings interface
*/
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 0d23baa5..8a042e70 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2717,20 +2717,19 @@
- {{ localText('来源键', 'Source key') }}
+ {{ localText('支付来源', 'Payment source') }}
-
{{
localText(
- '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。',
- 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.'
+ '留空表示自动路由;仅允许当前系统支持的官方或易支付来源。',
+ 'Leave blank for automatic routing. Only supported official or EasyPay sources are allowed.'
)
}}
@@ -3117,11 +3116,14 @@ import { adminAPI } from '@/api'
import {
appendAuthSourceDefaultsToUpdateRequest,
buildAuthSourceDefaultsState,
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
normalizeDefaultSubscriptionSettings,
} from '@/api/admin/settings'
import type {
AuthSourceDefaultsState,
AuthSourceType,
+ PaymentVisibleMethod,
SystemSettings,
UpdateSettingsRequest,
DefaultSubscriptionSetting,
@@ -3429,12 +3431,23 @@ function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string {
: form.payment_visible_method_wxpay_source
}
-function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) {
+function getPaymentVisibleMethodSourceSelectOptions(method: PaymentVisibleMethod) {
+ return getPaymentVisibleMethodSourceOptions(method).map((option) => ({
+ value: option.value,
+ label: localText(option.labelZh, option.labelEn),
+ }))
+}
+
+function setPaymentVisibleMethodSource(
+ method: 'alipay' | 'wxpay',
+ source: string | number | boolean | null
+) {
+ const normalized = normalizePaymentVisibleMethodSource(method, source)
if (method === 'alipay') {
- form.payment_visible_method_alipay_source = source
+ form.payment_visible_method_alipay_source = normalized
return
}
- form.payment_visible_method_wxpay_source = source
+ form.payment_visible_method_wxpay_source = normalized
}
// Proxies for web search emulation ProxySelector
@@ -3805,6 +3818,14 @@ async function loadSettings() {
Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings))
form.backend_mode_enabled = settings.backend_mode_enabled
form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions)
+ form.payment_visible_method_alipay_source = normalizePaymentVisibleMethodSource(
+ 'alipay',
+ settings.payment_visible_method_alipay_source
+ )
+ form.payment_visible_method_wxpay_source = normalizePaymentVisibleMethodSource(
+ 'wxpay',
+ settings.payment_visible_method_wxpay_source
+ )
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
settings.registration_email_suffix_whitelist
)
@@ -4070,8 +4091,14 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
- payment_visible_method_alipay_source: form.payment_visible_method_alipay_source,
- payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source,
+ payment_visible_method_alipay_source: normalizePaymentVisibleMethodSource(
+ 'alipay',
+ form.payment_visible_method_alipay_source
+ ),
+ payment_visible_method_wxpay_source: normalizePaymentVisibleMethodSource(
+ 'wxpay',
+ form.payment_visible_method_wxpay_source
+ ),
payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled,
payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled,
openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled,
@@ -4092,6 +4119,14 @@ async function saveSettings() {
}
}
Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated))
+ form.payment_visible_method_alipay_source = normalizePaymentVisibleMethodSource(
+ 'alipay',
+ updated.payment_visible_method_alipay_source
+ )
+ form.payment_visible_method_wxpay_source = normalizePaymentVisibleMethodSource(
+ 'wxpay',
+ updated.payment_visible_method_wxpay_source
+ )
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
updated.registration_email_suffix_whitelist
)
diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
new file mode 100644
index 00000000..f20170e9
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
@@ -0,0 +1,452 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { defineComponent, h, ref } from 'vue'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import SettingsView from '../SettingsView.vue'
+
+const {
+ getSettings,
+ updateSettings,
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig,
+ getAdminApiKey,
+ getOverloadCooldownSettings,
+ getStreamTimeoutSettings,
+ getRectifierSettings,
+ getBetaPolicySettings,
+ getGroups,
+ listProxies,
+ getProviders,
+ fetchPublicSettings,
+ adminSettingsFetch,
+ showError,
+ showSuccess,
+} = vi.hoisted(() => ({
+ getSettings: vi.fn(),
+ updateSettings: vi.fn(),
+ getWebSearchEmulationConfig: vi.fn(),
+ updateWebSearchEmulationConfig: vi.fn(),
+ getAdminApiKey: vi.fn(),
+ getOverloadCooldownSettings: vi.fn(),
+ getStreamTimeoutSettings: vi.fn(),
+ getRectifierSettings: vi.fn(),
+ getBetaPolicySettings: vi.fn(),
+ getGroups: vi.fn(),
+ listProxies: vi.fn(),
+ getProviders: vi.fn(),
+ fetchPublicSettings: vi.fn(),
+ adminSettingsFetch: vi.fn(),
+ showError: vi.fn(),
+ showSuccess: vi.fn(),
+}))
+
+vi.mock('@/api', () => ({
+ adminAPI: {
+ settings: {
+ getSettings,
+ updateSettings,
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig,
+ getAdminApiKey,
+ getOverloadCooldownSettings,
+ getStreamTimeoutSettings,
+ getRectifierSettings,
+ getBetaPolicySettings,
+ },
+ groups: {
+ getAll: getGroups,
+ },
+ proxies: {
+ list: listProxies,
+ },
+ payment: {
+ getProviders,
+ },
+ },
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError,
+ showSuccess,
+ showWarning: vi.fn(),
+ showInfo: vi.fn(),
+ fetchPublicSettings,
+ }),
+}))
+
+vi.mock('@/stores/adminSettings', () => ({
+ useAdminSettingsStore: () => ({
+ fetch: adminSettingsFetch,
+ }),
+}))
+
+vi.mock('@/composables/useClipboard', () => ({
+ useClipboard: () => ({
+ copyToClipboard: vi.fn(),
+ }),
+}))
+
+vi.mock('@/utils/apiError', () => ({
+ extractApiErrorMessage: () => 'error',
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual
('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ locale: ref('zh-CN'),
+ }),
+ }
+})
+
+const AppLayoutStub = { template: '
' }
+const ToggleStub = defineComponent({
+ props: {
+ modelValue: {
+ type: Boolean,
+ default: false,
+ },
+ },
+ emits: ['update:modelValue'],
+ setup(props, { emit }) {
+ return () =>
+ h('input', {
+ class: 'toggle-stub',
+ type: 'checkbox',
+ checked: props.modelValue,
+ onChange: (event: Event) => {
+ emit('update:modelValue', (event.target as HTMLInputElement).checked)
+ },
+ })
+ },
+})
+
+const SelectStub = defineComponent({
+ props: {
+ modelValue: {
+ type: [String, Number, Boolean, null],
+ default: '',
+ },
+ options: {
+ type: Array,
+ default: () => [],
+ },
+ placeholder: {
+ type: String,
+ default: '',
+ },
+ },
+ emits: ['update:modelValue', 'change'],
+ setup(props, { emit }) {
+ const onChange = (event: Event) => {
+ const target = event.target as HTMLSelectElement
+ emit('update:modelValue', target.value)
+ const option = (props.options as Array>).find(
+ (item) => String(item.value ?? '') === target.value
+ ) ?? null
+ emit('change', target.value, option)
+ }
+
+ return () =>
+ h(
+ 'select',
+ {
+ class: 'select-stub',
+ value: props.modelValue ?? '',
+ 'data-placeholder': props.placeholder,
+ onChange,
+ },
+ (props.options as Array>).map((option) =>
+ h(
+ 'option',
+ {
+ key: `${String(option.value ?? '')}:${String(option.label ?? '')}`,
+ value: option.value as string,
+ },
+ String(option.label ?? '')
+ )
+ )
+ )
+ },
+})
+
+const baseSettingsResponse = {
+ registration_enabled: true,
+ email_verify_enabled: false,
+ registration_email_suffix_whitelist: [],
+ promo_code_enabled: true,
+ invitation_code_enabled: false,
+ password_reset_enabled: false,
+ totp_enabled: false,
+ totp_encryption_key_configured: false,
+ default_balance: 0,
+ default_concurrency: 1,
+ default_subscriptions: [],
+ site_name: 'Sub2API',
+ site_logo: '',
+ site_subtitle: '',
+ api_base_url: '',
+ contact_info: '',
+ doc_url: '',
+ home_content: '',
+ hide_ccs_import_button: false,
+ table_default_page_size: 20,
+ table_page_size_options: [10, 20, 50, 100],
+ backend_mode_enabled: false,
+ custom_menu_items: [],
+ custom_endpoints: [],
+ frontend_url: '',
+ smtp_host: '',
+ smtp_port: 587,
+ smtp_username: '',
+ smtp_password_configured: false,
+ smtp_from_email: '',
+ smtp_from_name: '',
+ smtp_use_tls: true,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ turnstile_secret_key_configured: false,
+ linuxdo_connect_enabled: false,
+ linuxdo_connect_client_id: '',
+ linuxdo_connect_client_secret_configured: false,
+ linuxdo_connect_redirect_url: '',
+ oidc_connect_enabled: false,
+ oidc_connect_provider_name: 'OIDC',
+ oidc_connect_client_id: '',
+ oidc_connect_client_secret_configured: false,
+ oidc_connect_issuer_url: '',
+ oidc_connect_discovery_url: '',
+ oidc_connect_authorize_url: '',
+ oidc_connect_token_url: '',
+ oidc_connect_userinfo_url: '',
+ oidc_connect_jwks_url: '',
+ oidc_connect_scopes: 'openid email profile',
+ oidc_connect_redirect_url: '',
+ oidc_connect_frontend_redirect_url: '/auth/oidc/callback',
+ oidc_connect_token_auth_method: 'client_secret_post',
+ oidc_connect_use_pkce: true,
+ oidc_connect_validate_id_token: true,
+ oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256',
+ oidc_connect_clock_skew_seconds: 120,
+ oidc_connect_require_email_verified: false,
+ oidc_connect_userinfo_email_path: '',
+ oidc_connect_userinfo_id_path: '',
+ oidc_connect_userinfo_username_path: '',
+ enable_model_fallback: false,
+ fallback_model_anthropic: '',
+ fallback_model_openai: '',
+ fallback_model_gemini: '',
+ fallback_model_antigravity: '',
+ enable_identity_patch: false,
+ identity_patch_prompt: '',
+ ops_monitoring_enabled: false,
+ ops_realtime_monitoring_enabled: false,
+ ops_query_mode_default: 'auto',
+ ops_metrics_interval_seconds: 60,
+ min_claude_code_version: '',
+ max_claude_code_version: '',
+ allow_ungrouped_key_scheduling: false,
+ enable_fingerprint_unification: true,
+ enable_metadata_passthrough: false,
+ enable_cch_signing: false,
+ payment_enabled: true,
+ payment_min_amount: 1,
+ payment_max_amount: 10000,
+ payment_daily_limit: 50000,
+ payment_order_timeout_minutes: 30,
+ payment_max_pending_orders: 3,
+ payment_enabled_types: [],
+ payment_balance_disabled: false,
+ payment_balance_recharge_multiplier: 1,
+ payment_recharge_fee_rate: 0,
+ payment_load_balance_strategy: 'round-robin',
+ payment_product_name_prefix: '',
+ payment_product_name_suffix: '',
+ payment_help_image_url: '',
+ payment_help_text: '',
+ payment_cancel_rate_limit_enabled: false,
+ payment_cancel_rate_limit_max: 10,
+ payment_cancel_rate_limit_window: 1,
+ payment_cancel_rate_limit_unit: 'day',
+ payment_cancel_rate_limit_window_mode: 'rolling',
+ payment_visible_method_alipay_source: 'alipay_direct',
+ payment_visible_method_wxpay_source: 'invalid-source',
+ payment_visible_method_alipay_enabled: true,
+ payment_visible_method_wxpay_enabled: true,
+ openai_advanced_scheduler_enabled: false,
+ balance_low_notify_enabled: false,
+ balance_low_notify_threshold: 0,
+ balance_low_notify_recharge_url: '',
+ account_quota_notify_enabled: false,
+ account_quota_notify_emails: [],
+}
+
+function mountView() {
+ return mount(SettingsView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ Select: SelectStub,
+ Toggle: ToggleStub,
+ Icon: true,
+ ConfirmDialog: true,
+ PaymentProviderList: true,
+ PaymentProviderDialog: true,
+ GroupBadge: true,
+ GroupOptionItem: true,
+ ProxySelector: true,
+ ImageUpload: true,
+ BackupSettings: true,
+ },
+ },
+ })
+}
+
+async function openPaymentTab(wrapper: ReturnType) {
+ const paymentTabButton = wrapper
+ .findAll('button')
+ .find((node) => node.text().includes('admin.settings.tabs.payment'))
+
+ expect(paymentTabButton).toBeDefined()
+ await paymentTabButton?.trigger('click')
+ await flushPromises()
+}
+
+describe('admin SettingsView payment visible method controls', () => {
+ beforeEach(() => {
+ getSettings.mockReset()
+ updateSettings.mockReset()
+ getWebSearchEmulationConfig.mockReset()
+ updateWebSearchEmulationConfig.mockReset()
+ getAdminApiKey.mockReset()
+ getOverloadCooldownSettings.mockReset()
+ getStreamTimeoutSettings.mockReset()
+ getRectifierSettings.mockReset()
+ getBetaPolicySettings.mockReset()
+ getGroups.mockReset()
+ listProxies.mockReset()
+ getProviders.mockReset()
+ fetchPublicSettings.mockReset()
+ adminSettingsFetch.mockReset()
+ showError.mockReset()
+ showSuccess.mockReset()
+
+ getSettings.mockResolvedValue({ ...baseSettingsResponse })
+ updateSettings.mockImplementation(async (payload) => ({
+ ...baseSettingsResponse,
+ ...payload,
+ }))
+ getWebSearchEmulationConfig.mockResolvedValue({
+ enabled: false,
+ providers: [],
+ })
+ updateWebSearchEmulationConfig.mockResolvedValue({
+ enabled: false,
+ providers: [],
+ })
+ getAdminApiKey.mockResolvedValue({
+ exists: false,
+ masked_key: '',
+ })
+ getOverloadCooldownSettings.mockResolvedValue({
+ enabled: true,
+ cooldown_minutes: 10,
+ })
+ getStreamTimeoutSettings.mockResolvedValue({
+ enabled: true,
+ action: 'temp_unsched',
+ temp_unsched_minutes: 5,
+ threshold_count: 3,
+ threshold_window_minutes: 10,
+ })
+ getRectifierSettings.mockResolvedValue({
+ enabled: true,
+ thinking_signature_enabled: true,
+ thinking_budget_enabled: true,
+ apikey_signature_enabled: false,
+ apikey_signature_patterns: [],
+ })
+ getBetaPolicySettings.mockResolvedValue({
+ rules: [],
+ })
+ getGroups.mockResolvedValue([])
+ listProxies.mockResolvedValue({
+ items: [],
+ })
+ getProviders.mockResolvedValue({
+ data: [],
+ })
+ fetchPublicSettings.mockResolvedValue(undefined)
+ adminSettingsFetch.mockResolvedValue(undefined)
+ })
+
+ it('loads canonical source options and normalizes existing values', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await openPaymentTab(wrapper)
+
+ const paymentSourceSelects = wrapper
+ .findAll('select.select-stub')
+ .filter((node) => ['alipay', 'wxpay'].includes(node.attributes('data-placeholder')))
+
+ expect(paymentSourceSelects).toHaveLength(2)
+
+ const alipaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'alipay'
+ )
+ const wxpaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'wxpay'
+ )
+
+ expect(alipaySelect?.element.value).toBe('official_alipay')
+ expect(alipaySelect?.findAll('option').map((option) => option.element.value)).toEqual([
+ '',
+ 'official_alipay',
+ 'easypay_alipay',
+ ])
+
+ expect(wxpaySelect?.element.value).toBe('')
+ expect(wxpaySelect?.findAll('option').map((option) => option.element.value)).toEqual([
+ '',
+ 'official_wxpay',
+ 'easypay_wxpay',
+ ])
+ })
+
+ it('saves canonical source keys selected from the dropdowns', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await openPaymentTab(wrapper)
+
+ const paymentSourceSelects = wrapper
+ .findAll('select.select-stub')
+ .filter((node) => ['alipay', 'wxpay'].includes(node.attributes('data-placeholder')))
+
+ const alipaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'alipay'
+ )
+ const wxpaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'wxpay'
+ )
+
+ await alipaySelect?.setValue('easypay_alipay')
+ await wxpaySelect?.setValue('official_wxpay')
+ await wrapper.find('form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(updateSettings).toHaveBeenCalledTimes(1)
+ expect(updateSettings).toHaveBeenCalledWith(
+ expect.objectContaining({
+ payment_visible_method_alipay_source: 'easypay_alipay',
+ payment_visible_method_wxpay_source: 'official_wxpay',
+ payment_visible_method_alipay_enabled: true,
+ payment_visible_method_wxpay_enabled: true,
+ })
+ )
+ })
+})
--
GitLab
From f83fd59dcadf9348c97620235a4f0cd66241f48d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:05:09 +0800
Subject: [PATCH 074/261] Refine payment UX for wallet flows
---
backend/internal/payment/provider/alipay.go | 33 +++--
.../internal/payment/provider/alipay_test.go | 113 ++++++++++++++++++
frontend/src/i18n/locales/en.ts | 12 ++
frontend/src/i18n/locales/zh.ts | 12 ++
frontend/src/views/user/PaymentResultView.vue | 9 +-
frontend/src/views/user/PaymentView.vue | 33 ++++-
.../user/__tests__/PaymentResultView.spec.ts | 25 ++++
.../views/user/__tests__/paymentUx.spec.ts | 49 ++++++++
frontend/src/views/user/paymentUx.ts | 105 ++++++++++++++++
9 files changed, 373 insertions(+), 18 deletions(-)
create mode 100644 frontend/src/views/user/__tests__/paymentUx.spec.ts
create mode 100644 frontend/src/views/user/paymentUx.ts
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index af8a90c6..4f87e5a7 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -26,6 +26,18 @@ const (
alipayRefundSuffix = "-refund"
)
+var (
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ return client.TradeWapPay(param)
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ return client.TradePagePay(param)
+ }
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ return client.TradePreCreate(ctx, param)
+ }
+)
+
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type Alipay struct {
instanceID string
@@ -80,7 +92,7 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
}
// CreatePayment creates an Alipay payment page URL.
-func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
return nil, err
@@ -96,12 +108,12 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
}
if req.IsMobile {
- return a.createTrade(client, req, notifyURL, returnURL, true)
+ return a.createTrade(ctx, client, req, notifyURL, returnURL, true)
}
- return a.createTrade(client, req, notifyURL, returnURL, false)
+ return a.createTrade(ctx, client, req, notifyURL, returnURL, false)
}
-func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
+func (a *Alipay) createTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
if isMobile {
param := alipay.TradeWapPay{}
param.OutTradeNo = req.OrderID
@@ -111,7 +123,7 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
- payURL, err := client.TradeWapPay(param)
+ payURL, err := alipayTradeWapPay(client, param)
if err != nil {
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
@@ -121,22 +133,19 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
}, nil
}
- param := alipay.TradePagePay{}
+ param := alipay.TradePreCreate{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
param.Subject = req.Subject
- param.ProductCode = alipayProductCodePagePay
param.NotifyURL = notifyURL
- param.ReturnURL = returnURL
- payURL, err := client.TradePagePay(param)
+ resp, err := alipayTradePreCreate(ctx, client, param)
if err != nil {
- return nil, fmt.Errorf("alipay TradePagePay: %w", err)
+ return nil, fmt.Errorf("alipay TradePreCreate: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
- PayURL: payURL.String(),
- QRCode: payURL.String(),
+ QRCode: strings.TrimSpace(resp.QRCode),
}, nil
}
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 7b0ce0d8..6cc4246c 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -3,9 +3,14 @@
package provider
import (
+ "context"
"errors"
+ "net/url"
"strings"
"testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/smartwalle/alipay/v3"
)
func TestIsTradeNotExist(t *testing.T) {
@@ -130,3 +135,111 @@ func TestNewAlipay(t *testing.T) {
})
}
}
+
+func TestCreateTradeUsesPreCreateForDesktop(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ alipayTradeWapPay = origWapPay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ wapPayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ if param.OutTradeNo != "sub2_100" {
+ t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100")
+ }
+ if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" {
+ t.Fatalf("notify_url = %q", param.NotifyURL)
+ }
+ return &alipay.TradePreCreateRsp{
+ OutTradeNo: "sub2_100",
+ QRCode: "https://qr.alipay.example.com/precreate-token",
+ }, nil
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "88.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result", false)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 1 {
+ t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
+ }
+ if pagePayCalls != 0 {
+ t.Fatalf("page pay calls = %d, want 0", pagePayCalls)
+ }
+ if wapPayCalls != 0 {
+ t.Fatalf("wap pay calls = %d, want 0", wapPayCalls)
+ }
+ if resp.QRCode != "https://qr.alipay.example.com/precreate-token" {
+ t.Fatalf("qr_code = %q", resp.QRCode)
+ }
+ if resp.PayURL != "" {
+ t.Fatalf("pay_url = %q, want empty", resp.PayURL)
+ }
+}
+
+func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradeWapPay = origWapPay
+ })
+
+ preCreateCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ return &alipay.TradePreCreateRsp{}, nil
+ }
+
+ wapPayCalls := 0
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ if param.ReturnURL != "https://merchant.example.com/payment/result" {
+ t.Fatalf("return_url = %q", param.ReturnURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_101",
+ Amount: "18.00",
+ Subject: "Balance recharge",
+ IsMobile: true,
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result", true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 0 {
+ t.Fatalf("precreate calls = %d, want 0", preCreateCalls)
+ }
+ if wapPayCalls != 1 {
+ t.Fatalf("wap pay calls = %d, want 1", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for mobile wap pay")
+ }
+ if resp.QRCode != "" {
+ t.Fatalf("qr_code = %q, want empty", resp.QRCode)
+ }
+}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 7d058a74..e17ed616 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -5453,6 +5453,18 @@ export default {
errors: {
tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.',
cancelRateLimited: 'Too many cancellations. Please try again later.',
+ wechatH5NotAuthorized: 'This merchant has not enabled WeChat H5 payment. Open this page in WeChat to continue.',
+ wechatPaymentMpNotConfigured: 'This site has not completed WeChat MP/JSAPI payment setup, so in-app WeChat payment is unavailable right now.',
+ wechatJsapiUnavailable: 'WeChat payment could not be invoked in the current environment. Reopen this page inside WeChat and try again.',
+ wechatJsapiFailed: 'WeChat payment did not complete. Try invoking it again or switch to QR payment.',
+ wechatUnavailable: 'WeChat payment is temporarily unavailable. Please try again later.',
+ wechatOpenInWeChatHint: 'Open the current page inside WeChat, or switch to desktop WeChat QR payment.',
+ wechatScanOnDesktopHint: 'On desktop, use WeChat Scan to pay; on mobile, reopen the current page inside WeChat.',
+ wechatSwitchBrowserHint: 'Switch to desktop WeChat QR payment, or reopen this page in an external browser and retry.',
+ alipayDesktopUnavailable: 'The desktop Alipay flow could not generate a QR code.',
+ alipayDesktopQrHint: 'Desktop Alipay should render a QR code. Refresh and retry, or make sure the payment page was not blocked.',
+ alipayMobileUnavailable: 'This page could not hand off to Alipay.',
+ alipayMobileOpenHint: 'Allow the current page to open the Alipay app, or retry from the system browser.',
PENDING_ORDERS: 'This provider has pending orders. Please wait for them to complete before making changes.',
},
stripePay: 'Pay Now',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 6dd74334..d54c0aba 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -5641,6 +5641,18 @@ export default {
errors: {
tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单',
cancelRateLimited: '取消订单过于频繁,请稍后再试',
+ wechatH5NotAuthorized: '当前商户未开通微信 H5 支付,请在微信中打开当前页面继续支付。',
+ wechatPaymentMpNotConfigured: '当前站点未完成公众号/JSAPI 支付配置,暂时无法在微信内直接拉起支付。',
+ wechatJsapiUnavailable: '当前环境未能拉起微信支付,请确认正在微信内打开本页后重试。',
+ wechatJsapiFailed: '微信支付未完成,请重新拉起支付或改用扫码支付。',
+ wechatUnavailable: '当前微信支付暂不可用,请稍后重试。',
+ wechatOpenInWeChatHint: '请复制当前页面链接到微信内打开,或直接改用电脑端微信扫码支付。',
+ wechatScanOnDesktopHint: '电脑端请直接使用微信扫一扫完成支付;移动端请在微信内打开当前页面。',
+ wechatSwitchBrowserHint: '请改用电脑端微信扫码,或在外部浏览器重新打开本页后再试。',
+ alipayDesktopUnavailable: '当前支付宝桌面支付未成功生成二维码。',
+ alipayDesktopQrHint: '电脑端支付宝应展示扫码单,请刷新后重试,或确认浏览器未拦截当前支付页。',
+ alipayMobileUnavailable: '当前页面未成功跳转到支付宝。',
+ alipayMobileOpenHint: '请允许当前页面打开支付宝 App,或改用系统浏览器重新发起支付。',
PENDING_ORDERS: '该服务商有未完成的订单,请等待订单完成后再操作',
},
stripePay: '立即支付',
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 9687d1c7..53bbb550 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -54,7 +54,7 @@
{{ t('payment.orders.paymentMethod') }}
- {{ t('payment.methods.' + order.payment_type, order.payment_type) }}
+ {{ t(paymentMethodI18nKey(order.payment_type), normalizedOrderPaymentType(order.payment_type)) }}
{{ t('payment.orders.status') }}
@@ -75,7 +75,7 @@
{{ t('payment.orders.paymentMethod') }}
- {{ t('payment.methods.' + returnInfo.type, returnInfo.type) }}
+ {{ t(paymentMethodI18nKey(returnInfo.type), normalizedOrderPaymentType(returnInfo.type)) }}
@@ -98,6 +98,7 @@ import { PAYMENT_RECOVERY_STORAGE_KEY, readPaymentRecoverySnapshot } from '@/com
import { usePaymentStore } from '@/stores/payment'
import { paymentAPI } from '@/api/payment'
import type { PaymentOrder } from '@/types/payment'
+import { normalizePaymentMethodForDisplay, paymentMethodI18nKey } from './paymentUx'
const { t } = useI18n()
const route = useRoute()
@@ -133,6 +134,10 @@ const isSuccess = computed(() => {
return !!order.value && SUCCESS_STATUSES.has(order.value.status)
})
+function normalizedOrderPaymentType(paymentType: string): string {
+ return normalizePaymentMethodForDisplay(paymentType) || paymentType
+}
+
onMounted(async () => {
const resumeToken = typeof route.query.resume_token === 'string'
? route.query.resume_token
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index bfb9dae2..019de16a 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -88,6 +88,7 @@
{{ errorMessage }}
+
{{ errorHintMessage }}
@@ -174,6 +175,7 @@
{{ t('common.cancel') }}
{{ errorMessage }}
+
{{ errorHintMessage }}
@@ -281,6 +283,7 @@ import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue'
import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue'
import Icon from '@/components/icons/Icon.vue'
import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue'
+import { describePaymentScenarioError } from './paymentUx'
const { t } = useI18n()
const route = useRoute()
@@ -301,6 +304,7 @@ function getDaysRemaining(expiresAt: string): number {
const loading = ref(true)
const submitting = ref(false)
const errorMessage = ref('')
+const errorHintMessage = ref('')
const activeTab = ref<'recharge' | 'subscription'>('recharge')
const amount = ref
(null)
const selectedMethod = ref('')
@@ -619,6 +623,7 @@ async function confirmSubscribe() {
async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) {
submitting.value = true
errorMessage.value = ''
+ errorHintMessage.value = ''
try {
const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value
const payload = buildCreateOrderPayload({
@@ -668,8 +673,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
}
if (decision.kind === 'unhandled') {
- errorMessage.value = t('payment.result.failed')
- appStore.showError(errorMessage.value)
+ applyScenarioError({ reason: 'UNHANDLED_PAYMENT_SCENARIO' }, visibleMethod)
return
}
@@ -691,7 +695,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
if (errMsg.includes('cancel')) {
appStore.showInfo(t('payment.qr.cancelled'))
} else if (errMsg && !errMsg.includes('ok')) {
- appStore.showError(t('payment.result.failed'))
+ applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod)
}
return
}
@@ -707,10 +711,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
if (apiErr.reason === 'TOO_MANY_PENDING') {
const metadata = apiErr.metadata as Record | undefined
errorMessage.value = t('payment.errors.tooManyPending', { max: metadata?.max || '' })
+ errorHintMessage.value = ''
} else if (apiErr.reason === 'CANCEL_RATE_LIMITED') {
errorMessage.value = t('payment.errors.cancelRateLimited')
+ errorHintMessage.value = ''
} else {
- errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ applyScenarioError(err, normalizeVisibleMethod(options.paymentType || selectedMethod.value) || selectedMethod.value)
+ if (!errorMessage.value) {
+ errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ errorHintMessage.value = ''
+ }
}
appStore.showError(errorMessage.value)
} finally {
@@ -718,6 +728,21 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
}
}
+function applyScenarioError(err: unknown, paymentMethod: string) {
+ const descriptor = describePaymentScenarioError(err, {
+ paymentMethod,
+ isMobile: isMobileDevice(),
+ isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent),
+ })
+ if (!descriptor) {
+ errorMessage.value = ''
+ errorHintMessage.value = ''
+ return
+ }
+ errorMessage.value = t(descriptor.messageKey)
+ errorHintMessage.value = descriptor.hintKey ? t(descriptor.hintKey) : ''
+}
+
async function resumeWechatPaymentFromQuery() {
const openid = readRouteQueryValue(route.query.openid)
if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) {
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index d23a60d9..b1caa526 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -155,4 +155,29 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).toContain('payment.result.success')
expect(verifyOrderPublic).not.toHaveBeenCalled()
})
+
+ it('normalizes aliased payment methods before rendering the label', async () => {
+ routeState.query = {
+ resume_token: 'resume-88',
+ }
+ resolveOrderPublicByResumeToken.mockResolvedValueOnce({
+ data: {
+ ...orderFactory('PAID'),
+ payment_type: 'alipay_direct',
+ },
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('payment.methods.alipay')
+ expect(wrapper.text()).not.toContain('payment.methods.alipay_direct')
+ })
})
diff --git a/frontend/src/views/user/__tests__/paymentUx.spec.ts b/frontend/src/views/user/__tests__/paymentUx.spec.ts
new file mode 100644
index 00000000..6cf105f2
--- /dev/null
+++ b/frontend/src/views/user/__tests__/paymentUx.spec.ts
@@ -0,0 +1,49 @@
+import { describe, expect, it } from 'vitest'
+import {
+ describePaymentScenarioError,
+ normalizePaymentMethodForDisplay,
+} from '../paymentUx'
+
+describe('normalizePaymentMethodForDisplay', () => {
+ it('collapses visible payment aliases to canonical method ids', () => {
+ expect(normalizePaymentMethodForDisplay(' alipay_direct ')).toBe('alipay')
+ expect(normalizePaymentMethodForDisplay('wxpay_direct')).toBe('wxpay')
+ expect(normalizePaymentMethodForDisplay('wechat_pay')).toBe('wxpay')
+ })
+
+ it('leaves non-aliased methods untouched', () => {
+ expect(normalizePaymentMethodForDisplay('stripe')).toBe('stripe')
+ })
+})
+
+describe('describePaymentScenarioError', () => {
+ it('maps WeChat H5 authorization errors to explicit in-app guidance', () => {
+ expect(describePaymentScenarioError(
+ { reason: 'WECHAT_H5_NOT_AUTHORIZED' },
+ { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: false },
+ )).toEqual({
+ messageKey: 'payment.errors.wechatH5NotAuthorized',
+ hintKey: 'payment.errors.wechatOpenInWeChatHint',
+ })
+ })
+
+ it('maps missing WeixinJSBridge to a JSAPI-specific prompt', () => {
+ expect(describePaymentScenarioError(
+ new Error('WeixinJSBridge is unavailable'),
+ { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: true },
+ )).toEqual({
+ messageKey: 'payment.errors.wechatJsapiUnavailable',
+ hintKey: 'payment.errors.wechatOpenInWeChatHint',
+ })
+ })
+
+ it('maps generic desktop Alipay failures to QR guidance', () => {
+ expect(describePaymentScenarioError(
+ { reason: 'PAYMENT_GATEWAY_ERROR' },
+ { paymentMethod: 'alipay', isMobile: false, isWechatBrowser: false },
+ )).toEqual({
+ messageKey: 'payment.errors.alipayDesktopUnavailable',
+ hintKey: 'payment.errors.alipayDesktopQrHint',
+ })
+ })
+})
diff --git a/frontend/src/views/user/paymentUx.ts b/frontend/src/views/user/paymentUx.ts
new file mode 100644
index 00000000..443529a7
--- /dev/null
+++ b/frontend/src/views/user/paymentUx.ts
@@ -0,0 +1,105 @@
+import { normalizeVisibleMethod } from '@/components/payment/paymentFlow'
+import { extractApiErrorCode } from '@/utils/apiError'
+
+const DISPLAY_METHOD_ALIASES: Record = {
+ wechat: 'wxpay',
+ wechat_pay: 'wxpay',
+}
+
+export interface PaymentScenarioContext {
+ paymentMethod: string
+ isMobile: boolean
+ isWechatBrowser: boolean
+}
+
+export interface PaymentScenarioErrorDescriptor {
+ messageKey: string
+ hintKey?: string
+}
+
+export function normalizePaymentMethodForDisplay(paymentType: string): string {
+ const trimmed = paymentType.trim().toLowerCase()
+ const visibleMethod = normalizeVisibleMethod(trimmed)
+ if (visibleMethod) return visibleMethod
+ return DISPLAY_METHOD_ALIASES[trimmed] ?? trimmed
+}
+
+export function paymentMethodI18nKey(paymentType: string): string {
+ return `payment.methods.${normalizePaymentMethodForDisplay(paymentType)}`
+}
+
+function defaultWechatHint(context: PaymentScenarioContext): string {
+ if (!context.isMobile) return 'payment.errors.wechatScanOnDesktopHint'
+ return 'payment.errors.wechatOpenInWeChatHint'
+}
+
+function defaultAlipayHint(context: PaymentScenarioContext): string {
+ if (context.isMobile) return 'payment.errors.alipayMobileOpenHint'
+ return 'payment.errors.alipayDesktopQrHint'
+}
+
+export function describePaymentScenarioError(
+ error: unknown,
+ context: PaymentScenarioContext,
+): PaymentScenarioErrorDescriptor | null {
+ const method = normalizePaymentMethodForDisplay(context.paymentMethod)
+ const code = extractApiErrorCode(error)
+ const message = error instanceof Error
+ ? error.message
+ : (typeof error === 'object' && error && 'message' in error && typeof error.message === 'string'
+ ? error.message
+ : String(error || ''))
+ const normalizedMessage = message.toLowerCase()
+
+ if (method === 'wxpay') {
+ if (code === 'WECHAT_H5_NOT_AUTHORIZED') {
+ return {
+ messageKey: 'payment.errors.wechatH5NotAuthorized',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ if (code === 'WECHAT_PAYMENT_MP_NOT_CONFIGURED') {
+ return {
+ messageKey: 'payment.errors.wechatPaymentMpNotConfigured',
+ hintKey: context.isWechatBrowser
+ ? 'payment.errors.wechatSwitchBrowserHint'
+ : defaultWechatHint(context),
+ }
+ }
+ if (code === 'NO_AVAILABLE_INSTANCE') {
+ return {
+ messageKey: 'payment.errors.wechatUnavailable',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ if (code === 'WECHAT_JSAPI_FAILED' || normalizedMessage.includes('get_brand_wcpay_request:fail')) {
+ return {
+ messageKey: 'payment.errors.wechatJsapiFailed',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ if (normalizedMessage.includes('weixinjsbridge is unavailable')) {
+ return {
+ messageKey: 'payment.errors.wechatJsapiUnavailable',
+ hintKey: 'payment.errors.wechatOpenInWeChatHint',
+ }
+ }
+ if (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO') {
+ return {
+ messageKey: 'payment.errors.wechatUnavailable',
+ hintKey: defaultWechatHint(context),
+ }
+ }
+ }
+
+ if (method === 'alipay' && (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO')) {
+ return {
+ messageKey: context.isMobile
+ ? 'payment.errors.alipayMobileUnavailable'
+ : 'payment.errors.alipayDesktopUnavailable',
+ hintKey: defaultAlipayHint(context),
+ }
+ }
+
+ return null
+}
--
GitLab
From 9e84e2fd2bed765d38ca4a438fbbf1739950ae32 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:05:17 +0800
Subject: [PATCH 075/261] fix: persist admin payment visibility and scheduler
settings
---
.../internal/handler/admin/setting_handler.go | 64 ++++++++++++++++
...tting_handler_auth_source_defaults_test.go | 74 +++++++++++++++++++
backend/internal/handler/dto/settings.go | 9 +++
backend/internal/server/api_contract_test.go | 63 +++++++++++++++-
.../service/admin_service_apikey_test.go | 6 ++
backend/internal/service/setting_service.go | 54 +++++++++++++-
.../service/setting_service_update_test.go | 31 ++++++++
backend/internal/service/settings_view.go | 9 +++
backend/internal/service/user_service_test.go | 6 ++
9 files changed, 311 insertions(+), 5 deletions(-)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index fe5c7928..e5681208 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -180,6 +180,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
@@ -338,6 +343,15 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
+
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
@@ -935,6 +949,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ PaymentVisibleMethodAlipaySource: func() string {
+ if req.PaymentVisibleMethodAlipaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
+ }
+ return previousSettings.PaymentVisibleMethodAlipaySource
+ }(),
+ PaymentVisibleMethodWxpaySource: func() string {
+ if req.PaymentVisibleMethodWxpaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource)
+ }
+ return previousSettings.PaymentVisibleMethodWxpaySource
+ }(),
+ PaymentVisibleMethodAlipayEnabled: func() bool {
+ if req.PaymentVisibleMethodAlipayEnabled != nil {
+ return *req.PaymentVisibleMethodAlipayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodAlipayEnabled
+ }(),
+ PaymentVisibleMethodWxpayEnabled: func() bool {
+ if req.PaymentVisibleMethodWxpayEnabled != nil {
+ return *req.PaymentVisibleMethodWxpayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodWxpayEnabled
+ }(),
+ OpenAIAdvancedSchedulerEnabled: func() bool {
+ if req.OpenAIAdvancedSchedulerEnabled != nil {
+ return *req.OpenAIAdvancedSchedulerEnabled
+ }
+ return previousSettings.OpenAIAdvancedSchedulerEnabled
+ }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
@@ -1153,6 +1197,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
@@ -1455,6 +1504,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
+ if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
+ changed = append(changed, "payment_visible_method_alipay_source")
+ }
+ if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource {
+ changed = append(changed, "payment_visible_method_wxpay_source")
+ }
+ if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled {
+ changed = append(changed, "payment_visible_method_alipay_enabled")
+ }
+ if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled {
+ changed = append(changed, "payment_visible_method_wxpay_enabled")
+ }
+ if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
+ changed = append(changed, "openai_advanced_scheduler_enabled")
+ }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
index b26fa447..bf51fc68 100644
--- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -147,3 +147,77 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
}
+
+func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "easypay",
+ "payment_visible_method_wxpay_source": "wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"])
+ require.Equal(t, true, data["payment_visible_method_alipay_enabled"])
+ require.Equal(t, false, data["payment_visible_method_wxpay_enabled"])
+ require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
+}
+
+func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "bogus",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 637e317b..cc3f8496 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -127,6 +127,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index e903898f..533f2dac 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -500,10 +500,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- service.SettingKeyOpsMonitoringEnabled: "false",
- service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
- service.SettingKeyOpsQueryModeDefault: "auto",
- service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingKeyOpsMonitoringEnabled: "false",
+ service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
+ service.SettingKeyOpsQueryModeDefault: "auto",
+ service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay,
+ service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat,
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ "openai_advanced_scheduler_enabled": "true",
})
},
method: http.MethodGet,
@@ -567,6 +572,27 @@ func TestAPIContracts(t *testing.T) {
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_subscriptions": [],
@@ -592,6 +618,11 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "easypay_alipay",
+ "payment_visible_method_wxpay_source": "official_wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
@@ -858,6 +889,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
+func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -894,6 +937,18 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
return errors.New("not implemented")
}
+func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ return nil, nil
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index 487fb5f1..e2eae0b4 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -82,6 +82,12 @@ func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error {
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected")
}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index a2644fcd..02a64c1c 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -566,6 +566,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
normalizedWhitelist = []string{}
}
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
+ alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
+ if err != nil {
+ return err
+ }
+ wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
+ if err != nil {
+ return err
+ }
+ settings.PaymentVisibleMethodAlipaySource = alipaySource
+ settings.PaymentVisibleMethodWxpaySource = wxpaySource
updates := make(map[string]string)
@@ -701,6 +711,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
+ updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
+ updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
+ updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
+ updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
+ updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
@@ -730,6 +745,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
cchSigning: settings.EnableCCHSigning,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
+ openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: settings.OpenAIAdvancedSchedulerEnabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
@@ -1137,7 +1157,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyMaxClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
- SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingPaymentVisibleMethodAlipaySource: "",
+ SettingPaymentVisibleMethodWxpaySource: "",
+ SettingPaymentVisibleMethodAlipayEnabled: "false",
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ openAIAdvancedSchedulerSettingKey: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -1429,6 +1454,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
+ result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource])
+ result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource])
+ result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true"
+ result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
+ result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
@@ -1458,6 +1488,28 @@ func isFalseSettingValue(value string) bool {
}
}
+func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) {
+ source = strings.TrimSpace(source)
+ if source == "" {
+ if enabled {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source is required when the visible method is enabled", method),
+ )
+ }
+ return "", nil
+ }
+
+ normalized := NormalizeVisibleMethodSource(method, source)
+ if normalized == "" {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source must be one of the supported payment providers", method),
+ )
+ }
+ return normalized, nil
+}
+
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
raw = strings.TrimSpace(raw)
if raw == "" {
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
index e62218b4..9dc0ca59 100644
--- a/backend/internal/service/setting_service_update_test.go
+++ b/backend/internal/service/setting_service_update_test.go
@@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions])
}
+
+func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "alipay",
+ PaymentVisibleMethodWxpaySource: "easypay",
+ PaymentVisibleMethodAlipayEnabled: true,
+ PaymentVisibleMethodWxpayEnabled: false,
+ OpenAIAdvancedSchedulerEnabled: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey])
+}
+
+func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "not-a-provider",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err))
+ require.Nil(t, repo.updates)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 72db4e31..9bd461f9 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -110,6 +110,15 @@ type SystemSettings struct {
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string
+ PaymentVisibleMethodWxpaySource string
+ PaymentVisibleMethodAlipayEnabled bool
+ PaymentVisibleMethodWxpayEnabled bool
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool
+
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 89f0362e..e13fb95d 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -103,6 +103,12 @@ func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) er
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
return nil, nil
}
+func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
--
GitLab
From 4f6966d7b3979809da9238498bf403ced840d20b Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:05:42 +0800
Subject: [PATCH 076/261] frontend: route wechat oauth entry by public settings
---
frontend/src/api/auth.ts | 55 ++++++
.../components/auth/WechatOAuthSection.vue | 62 ++++++-
.../auth/__tests__/WechatOAuthSection.spec.ts | 170 ++++++++++++++++--
3 files changed, 261 insertions(+), 26 deletions(-)
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 6a2feb87..6c877d76 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -349,6 +349,61 @@ export async function getPublicSettings(): Promise {
return data
}
+export type WeChatOAuthMode = 'open' | 'mp'
+export type WeChatOAuthUnavailableReason =
+ | 'not_configured'
+ | 'external_browser_required'
+ | 'wechat_browser_required'
+
+export interface ResolvedWeChatOAuthStart {
+ mode: WeChatOAuthMode | null
+ openEnabled: boolean
+ mpEnabled: boolean
+ isWeChatBrowser: boolean
+ unavailableReason: WeChatOAuthUnavailableReason | null
+}
+
+type WeChatOAuthPublicSettings = {
+ wechat_oauth_enabled?: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+}
+
+export function resolveWeChatOAuthStart(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ ? settings.wechat_oauth_open_enabled
+ : legacyEnabled
+ const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+ ? settings.wechat_oauth_mp_enabled
+ : legacyEnabled
+
+ if (isWeChatBrowser) {
+ if (mpEnabled) {
+ return { mode: 'mp', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (openEnabled) {
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+ }
+
+ if (openEnabled) {
+ return { mode: 'open', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (mpEnabled) {
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+}
+
/**
* Send verification code to email
* @param request - Email and optional Turnstile token
diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue
index 94e20222..01bbd180 100644
--- a/frontend/src/components/auth/WechatOAuthSection.vue
+++ b/frontend/src/components/auth/WechatOAuthSection.vue
@@ -1,6 +1,6 @@
-
+
@@ -9,6 +9,14 @@
{{ t('auth.oidc.signIn', { providerName }) }}
+
+ {{ disabledHint }}
+
+
@@ -20,33 +28,69 @@
diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
new file mode 100644
index 00000000..5e6b0ae0
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
@@ -0,0 +1,243 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+import { defineComponent, h } from 'vue'
+
+import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue'
+
+const { getAuthIdentityMigrationReportSummary, listAuthIdentityMigrationReports, resolveAuthIdentityMigrationReport } = vi.hoisted(() => ({
+ getAuthIdentityMigrationReportSummary: vi.fn(),
+ listAuthIdentityMigrationReports: vi.fn(),
+ resolveAuthIdentityMigrationReport: vi.fn(),
+}))
+
+const { showError, showSuccess } = vi.hoisted(() => ({
+ showError: vi.fn(),
+ showSuccess: vi.fn(),
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ users: {
+ getAuthIdentityMigrationReportSummary,
+ listAuthIdentityMigrationReports,
+ resolveAuthIdentityMigrationReport,
+ },
+ },
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showError,
+ showSuccess,
+ }),
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ locale: { value: 'en' },
+ t: (key: string) => key,
+ }),
+ }
+})
+
+vi.mock('@/utils/format', () => ({
+ formatDateTime: (value: string | null | undefined) => value ?? '',
+}))
+
+const sampleReport = {
+ id: 1,
+ report_type: 'oidc_synthetic_email_requires_manual_recovery',
+ report_key: 'legacy@example.invalid',
+ details: {
+ user_id: 42,
+ legacy_email: 'legacy@example.invalid',
+ provider_key: 'https://issuer.example',
+ provider_subject: 'subject-123',
+ },
+ created_at: '2026-04-20T01:02:03Z',
+ resolved_at: null,
+ resolved_by_user_id: null,
+ resolution_note: '',
+}
+
+const summaryResponse = {
+ total: 2,
+ open_total: 1,
+ resolved_total: 1,
+ by_type: {
+ oidc_synthetic_email_requires_manual_recovery: 2,
+ },
+}
+
+const listResponse = {
+ items: [sampleReport],
+ total: 1,
+ page: 1,
+ page_size: 20,
+ pages: 1,
+}
+
+const AppLayoutStub = defineComponent({
+ setup(_, { slots }) {
+ return () => h('div', slots.default?.())
+ },
+})
+
+const TablePageLayoutStub = defineComponent({
+ setup(_, { slots }) {
+ return () => h('div', [
+ slots.actions?.(),
+ slots.filters?.(),
+ slots.table?.(),
+ slots.default?.(),
+ slots.pagination?.(),
+ ])
+ },
+})
+
+const DataTableStub = defineComponent({
+ props: {
+ columns: { type: Array, default: () => [] },
+ data: { type: Array, default: () => [] },
+ loading: { type: Boolean, default: false },
+ },
+ setup(props, { slots }) {
+ return () => h('div', { 'data-test': 'data-table' }, [
+ props.loading
+ ? h('div', 'loading')
+ : (props.data as Array<Record<string, unknown>>).map((row) =>
+ h(
+ 'div',
+ { key: String(row.id ?? row.report_key) },
+ (props.columns as Array<{ key: string }>).map((column) => {
+ const slot = slots[`cell-${column.key}`]
+ return h(
+ 'div',
+ { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` },
+ slot
+ ? slot({ row, value: row[column.key] })
+ : String(row[column.key] ?? '')
+ )
+ })
+ )
+ ),
+ ])
+ },
+})
+
+const PaginationStub = defineComponent({
+ props: {
+ total: { type: Number, required: true },
+ page: { type: Number, required: true },
+ pageSize: { type: Number, required: true },
+ },
+ emits: ['update:page', 'update:pageSize'],
+ setup(props, { emit }) {
+ return () => h('div', { 'data-test': 'pagination' }, [
+ h('button', {
+ type: 'button',
+ 'data-test': 'next-page',
+ onClick: () => emit('update:page', props.page + 1),
+ }, 'next'),
+ h('button', {
+ type: 'button',
+ 'data-test': 'page-size-50',
+ onClick: () => emit('update:pageSize', 50),
+ }, '50'),
+ ])
+ },
+})
+
+describe('AuthIdentityMigrationReportsView', () => {
+ beforeEach(() => {
+ getAuthIdentityMigrationReportSummary.mockReset()
+ listAuthIdentityMigrationReports.mockReset()
+ resolveAuthIdentityMigrationReport.mockReset()
+ showError.mockReset()
+ showSuccess.mockReset()
+
+ getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse)
+ listAuthIdentityMigrationReports.mockResolvedValue(listResponse)
+ resolveAuthIdentityMigrationReport.mockResolvedValue({
+ ...sampleReport,
+ resolved_at: '2026-04-20T02:00:00Z',
+ resolved_by_user_id: 100,
+ resolution_note: 'resolved by admin',
+ })
+ })
+
+ const mountView = () =>
+ mount(AuthIdentityMigrationReportsView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ TablePageLayout: TablePageLayoutStub,
+ DataTable: DataTableStub,
+ Pagination: PaginationStub,
+ Icon: true,
+ },
+ },
+ })
+
+ it('loads summary and first page of reports on mount', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1)
+ expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
+ page: 1,
+ pageSize: 20,
+ reportType: '',
+ })
+ expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2')
+ expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1')
+ expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1')
+ expect(wrapper.text()).toContain('legacy@example.invalid')
+ })
+
+ it('reloads list when the report type filter changes', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ listAuthIdentityMigrationReports.mockClear()
+
+ await wrapper.get('[data-test="report-type-filter"]').setValue(
+ 'oidc_synthetic_email_requires_manual_recovery'
+ )
+ await flushPromises()
+
+ expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
+ page: 1,
+ pageSize: 20,
+ reportType: 'oidc_synthetic_email_requires_manual_recovery',
+ })
+ })
+
+ it('submits resolve note for the selected report and refreshes data', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ getAuthIdentityMigrationReportSummary.mockClear()
+ listAuthIdentityMigrationReports.mockClear()
+
+ await wrapper.get('[data-test="select-report-1"]').trigger('click')
+ await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin')
+ await wrapper.get('[data-test="resolve-submit"]').trigger('click')
+ await flushPromises()
+
+ expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin')
+ expect(showSuccess).toHaveBeenCalled()
+ expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1)
+ expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({
+ page: 1,
+ pageSize: 20,
+ reportType: '',
+ })
+ })
+})
--
GitLab
From 9204145746623d382f2fe5f9f2f44723042942dd Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:11:03 +0800
Subject: [PATCH 079/261] Close profile identity and avatar loop
---
backend/internal/handler/auth_handler.go | 1 +
.../internal/handler/auth_linuxdo_oauth.go | 2 +-
.../handler/auth_oauth_pending_flow.go | 15 +-
.../handler/auth_oauth_pending_flow_test.go | 145 ++++++++++++-
backend/internal/handler/auth_oidc_oauth.go | 2 +-
backend/internal/handler/auth_wechat_oauth.go | 2 +-
backend/internal/handler/user_handler.go | 112 +++++++++-
backend/internal/handler/user_handler_test.go | 79 +++++++
backend/internal/service/user_service.go | 75 +++++--
frontend/src/api/user.ts | 1 +
.../user/profile/ProfileInfoCard.vue | 202 +++++++++++++++++-
.../profile/__tests__/ProfileInfoCard.spec.ts | 172 +++++++++++++++
frontend/src/i18n/locales/en.ts | 14 ++
frontend/src/i18n/locales/zh.ts | 14 ++
14 files changed, 801 insertions(+), 35 deletions(-)
create mode 100644 frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index e4697609..b984a436 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -296,6 +296,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
c.Request.Context(),
h.entClient(),
h.authService,
+ h.userService,
pendingSession,
decision,
&user.ID,
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index a0760a3b..175b1e1f 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 1b3c1380..94186858 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -852,6 +852,7 @@ func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
+ userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
@@ -938,6 +939,12 @@ func applyPendingOAuthBinding(
}
}
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil {
+ if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
+ return err
+ }
+ }
+
return tx.Commit()
}
@@ -945,6 +952,7 @@ func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
+ userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
@@ -953,6 +961,7 @@ func applyPendingOAuthAdoption(
ctx,
client,
authService,
+ userService,
session,
decision,
overrideUserID,
@@ -1092,7 +1101,7 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
})
return
}
- if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil {
+ if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
@@ -1188,7 +1197,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil {
+ if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
@@ -1278,7 +1287,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 8c468fdc..2521186e 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -152,6 +152,11 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.NotNil(t, avatar)
+ require.Equal(t, "remote_url", avatar.StorageProvider)
+ require.Equal(t, "https://cdn.example/alice.png", avatar.URL)
+
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
@@ -1242,6 +1247,18 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_avatars (
+ user_id INTEGER PRIMARY KEY,
+ storage_provider TEXT NOT NULL,
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL,
+ content_type TEXT NOT NULL DEFAULT '',
+ byte_size INTEGER NOT NULL DEFAULT 0,
+ sha256 TEXT NOT NULL DEFAULT '',
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+)`)
+ require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@@ -1492,6 +1509,35 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin
return payload
}
+type oauthPendingFlowAvatarRecord struct {
+ StorageProvider string
+ URL string
+}
+
+func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+
+ if !rows.Next() {
+ require.NoError(t, rows.Err())
+ return nil
+ }
+
+ var record oauthPendingFlowAvatarRecord
+ require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL))
+ require.NoError(t, rows.Err())
+ return &record
+}
+
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,
@@ -1604,16 +1650,95 @@ func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
return r.client.User.DeleteOneID(id).Exec(ctx)
}
-func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
- return nil, nil
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var rows entsql.Rows
+ if err := driver.Query(
+ ctx,
+ `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ ); err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
}
-func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
- panic("unexpected UpsertUserAvatar call")
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ if err := driver.Exec(
+ ctx,
+ `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+ON CONFLICT(user_id) DO UPDATE SET
+ storage_provider = excluded.storage_provider,
+ storage_key = excluded.storage_key,
+ url = excluded.url,
+ content_type = excluded.content_type,
+ byte_size = excluded.byte_size,
+ sha256 = excluded.sha256,
+ updated_at = CURRENT_TIMESTAMP`,
+ []any{
+ userID,
+ input.StorageProvider,
+ input.StorageKey,
+ input.URL,
+ input.ContentType,
+ input.ByteSize,
+ input.SHA256,
+ },
+ &result,
+ ); err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
}
-func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
- return nil
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result)
}
func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
@@ -1636,6 +1761,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
panic("unexpected UpdateConcurrency call")
}
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
return count > 0, err
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 70424ec5..0f9f1895 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 4d25a763..b078b804 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -517,7 +517,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 9dcff828..b1ade5c0 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -2,6 +2,7 @@ package handler
import (
"context"
+ "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -43,8 +44,24 @@ type UpdateProfileRequest struct {
type userProfileResponse struct {
dto.User
- AvatarURL string `json:"avatar_url,omitempty"`
- Identities service.UserIdentitySummarySet `json:"identities"`
+ AvatarURL string `json:"avatar_url,omitempty"`
+ AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
+ UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
+ DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
+ NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
+ ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
+ Identities service.UserIdentitySummarySet `json:"identities"`
+ AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
+ IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
+ EmailBound bool `json:"email_bound"`
+ LinuxDoBound bool `json:"linuxdo_bound"`
+ OIDCBound bool `json:"oidc_bound"`
+ WeChatBound bool `json:"wechat_bound"`
+}
+
+type userProfileSourceContext struct {
+ Provider string `json:"provider,omitempty"`
+ Source string `json:"source,omitempty"`
}
// GetProfile handles getting user profile
@@ -335,9 +352,94 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
if base == nil {
return userProfileResponse{}
}
+ bindings := userProfileBindingMap(identities)
+ profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
- User: *base,
- AvatarURL: user.AvatarURL,
- Identities: identities,
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ AvatarSource: avatarSource,
+ UsernameSource: usernameSource,
+ DisplayNameSource: usernameSource,
+ NicknameSource: usernameSource,
+ ProfileSources: profileSources,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
+ }
+}
+
+func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
+ return map[string]service.UserIdentitySummary{
+ "email": identities.Email,
+ "linuxdo": identities.LinuxDo,
+ "oidc": identities.OIDC,
+ "wechat": identities.WeChat,
+ }
+}
+
+func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
+ map[string]*userProfileSourceContext,
+ *userProfileSourceContext,
+ *userProfileSourceContext,
+) {
+ if user == nil {
+ return nil, nil, nil
+ }
+
+ thirdParty := thirdPartyIdentityProviders(identities)
+ var avatarSource *userProfileSourceContext
+ if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
+ avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
+ }
+
+ usernameValue := strings.TrimSpace(user.Username)
+ var usernameSource *userProfileSourceContext
+ for _, summary := range thirdParty {
+ if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
+ usernameSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+ if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
+ usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
+ }
+
+ profileSources := map[string]*userProfileSourceContext{}
+ if avatarSource != nil {
+ profileSources["avatar"] = avatarSource
+ }
+ if usernameSource != nil {
+ profileSources["username"] = usernameSource
+ profileSources["display_name"] = usernameSource
+ profileSources["nickname"] = usernameSource
+ }
+ if len(profileSources) == 0 {
+ return nil, avatarSource, usernameSource
+ }
+ return profileSources, avatarSource, usernameSource
+}
+
+func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
+ out := make([]service.UserIdentitySummary, 0, 3)
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ if summary.Bound {
+ out = append(out, summary)
+ }
+ }
+ return out
+}
+
+func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
+ provider = strings.TrimSpace(provider)
+ if provider == "" {
+ return nil
+ }
+ return &userProfileSourceContext{
+ Provider: provider,
+ Source: provider,
}
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index b71846c1..1216f9c4 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -92,6 +92,12 @@ func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int6
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
@@ -230,6 +236,79 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start")
}
+func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "legacy-profile@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, false, resp.Data["oidc_bound"])
+ require.Equal(t, false, resp.Data["wechat_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+ require.Equal(t, "linuxdo", linuxdoBinding["provider"])
+
+ identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
+ require.True(t, ok)
+ emailBinding, ok := identityBindings["email"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, emailBinding["bound"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+}
+
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index c52f91bb..c106a3f5 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -65,6 +65,8 @@ type UserRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
+ GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error)
+ GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -159,6 +161,33 @@ type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
+type emailAuthIdentitySynchronizer interface {
+ EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
+ ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
+}
+
+func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
+ syncer, ok := repo.(emailAuthIdentitySynchronizer)
+ if !ok {
+ return nil
+ }
+ return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
+}
+
+func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
+ oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
+ newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
+ if oldNormalized == newNormalized {
+ return nil
+ }
+
+ syncer, ok := repo.(emailAuthIdentitySynchronizer)
+ if !ok {
+ return nil
+ }
+ return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -252,6 +281,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
+ oldEmail := user.Email
// 更新字段
if req.Email != nil {
@@ -271,24 +301,11 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
}
if req.AvatarURL != nil {
- avatarValue := strings.TrimSpace(*req.AvatarURL)
- switch {
- case avatarValue == "":
- if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
- return nil, fmt.Errorf("delete avatar: %w", err)
- }
- applyUserAvatar(user, nil)
- default:
- avatarInput, err := normalizeUserAvatarInput(avatarValue)
- if err != nil {
- return nil, err
- }
- avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
- if err != nil {
- return nil, fmt.Errorf("upsert avatar: %w", err)
- }
- applyUserAvatar(user, avatar)
+ avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL)
+ if err != nil {
+ return nil, err
}
+ applyUserAvatar(user, avatar)
}
if req.Concurrency != nil {
@@ -309,6 +326,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
+ if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
+ return nil, fmt.Errorf("sync email auth identity: %w", err)
+ }
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
@@ -316,6 +336,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil
}
+func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) {
+ avatarValue := strings.TrimSpace(raw)
+ if avatarValue == "" {
+ if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
+ return nil, fmt.Errorf("delete avatar: %w", err)
+ }
+ return nil, nil
+ }
+
+ avatarInput, err := normalizeUserAvatarInput(avatarValue)
+ if err != nil {
+ return nil, err
+ }
+
+ avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
+ if err != nil {
+ return nil, fmt.Errorf("upsert avatar: %w", err)
+ }
+ return avatar, nil
+}
+
func applyUserAvatar(user *User, avatar *UserAvatar) {
if user == nil {
return
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index 1f6e4cd9..c7b1e503 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -23,6 +23,7 @@ export async function getProfile(): Promise {
*/
export async function updateProfile(profile: {
username?: string
+ avatar_url?: string | null
balance_notify_enabled?: boolean
balance_notify_threshold?: number | null
balance_notify_extra_emails?: NotifyEmailEntry[]
diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue
index e82ae229..d6273431 100644
--- a/frontend/src/components/user/profile/ProfileInfoCard.vue
+++ b/frontend/src/components/user/profile/ProfileInfoCard.vue
@@ -61,6 +61,71 @@
+
+
+
+
+ {{ t('profile.avatar.title') }}
+
+
+ {{ t('profile.avatar.description') }}
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+ {{ t('profile.avatar.inputLabel') }}
+
+
+
+
+
+ {{ t('profile.avatar.uploadAction') }}
+
+
+ {{ t('common.save') }}
+
+
+ {{ t('profile.avatar.uploadHint') }}
+
+
+
+
+
diff --git a/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
new file mode 100644
index 00000000..238a898e
--- /dev/null
+++ b/frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts
@@ -0,0 +1,172 @@
+import { mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import ProfileInfoCard from '@/components/user/profile/ProfileInfoCard.vue'
+import type { User } from '@/types'
+
+const {
+ updateProfileMock,
+ showSuccessMock,
+ showErrorMock,
+ authStoreState
+} = vi.hoisted(() => ({
+ updateProfileMock: vi.fn(),
+ showSuccessMock: vi.fn(),
+ showErrorMock: vi.fn(),
+ authStoreState: {
+ user: null as User | null
+ }
+}))
+
+vi.mock('@/api', () => ({
+ userAPI: {
+ updateProfile: updateProfileMock
+ }
+}))
+
+vi.mock('@/stores/auth', () => ({
+ useAuthStore: () => authStoreState
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showSuccess: showSuccessMock,
+ showError: showErrorMock
+ })
+}))
+
+vi.mock('@/utils/apiError', () => ({
+ extractApiErrorMessage: (error: unknown) => (error as Error).message || 'request failed'
+}))
+
+vi.mock('vue-i18n', async (importOriginal) => {
+ const actual = await importOriginal()
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string, params?: Record) => {
+ if (key === 'profile.administrator') return 'Administrator'
+ if (key === 'profile.user') return 'User'
+ if (key === 'profile.avatar.title') return 'Profile avatar'
+ if (key === 'profile.avatar.description') return 'Set avatar by image URL or upload'
+ if (key === 'profile.avatar.inputLabel') return 'Avatar URL or data URL'
+ if (key === 'profile.avatar.inputPlaceholder') return 'https://cdn.example.com/avatar.png'
+ if (key === 'profile.avatar.uploadAction') return 'Upload image'
+ if (key === 'profile.avatar.uploadHint') return 'Images must be 100KB or smaller'
+ if (key === 'profile.avatar.saveSuccess') return 'Avatar updated'
+ if (key === 'profile.avatar.deleteSuccess') return 'Avatar removed'
+ if (key === 'profile.avatar.invalidType') return 'Please choose an image file'
+ if (key === 'profile.avatar.fileTooLarge') return 'Avatar image must be 100KB or smaller'
+ if (key === 'profile.avatar.invalidValue') return 'Enter a valid avatar URL or image data URL'
+ if (key === 'profile.avatar.emptyDeleteHint') return 'Avatar already removed'
+ if (key === 'profile.authBindings.providers.email') return 'Email'
+ if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo'
+ if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
+ if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
+ if (key === 'common.save') return 'Save'
+ if (key === 'common.delete') return 'Delete'
+ return key
+ }
+ })
+ }
+})
+
+function createUser(overrides: Partial = {}): User {
+ return {
+ id: 5,
+ username: 'alice',
+ email: 'alice@example.com',
+ avatar_url: null,
+ role: 'user',
+ balance: 10,
+ concurrency: 2,
+ status: 'active',
+ allowed_groups: null,
+ balance_notify_enabled: true,
+ balance_notify_threshold: null,
+ balance_notify_extra_emails: [],
+ created_at: '2026-04-20T00:00:00Z',
+ updated_at: '2026-04-20T00:00:00Z',
+ ...overrides
+ }
+}
+
+describe('ProfileInfoCard', () => {
+ beforeEach(() => {
+ updateProfileMock.mockReset()
+ showSuccessMock.mockReset()
+ showErrorMock.mockReset()
+ authStoreState.user = null
+ })
+
+ it('saves a remote avatar URL and updates the auth store', async () => {
+ const updatedUser = createUser({ avatar_url: 'https://cdn.example.com/new.png' })
+ updateProfileMock.mockResolvedValue(updatedUser)
+ authStoreState.user = createUser()
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ await wrapper.get('[data-testid="profile-avatar-input"]').setValue('https://cdn.example.com/new.png')
+ await wrapper.get('[data-testid="profile-avatar-save"]').trigger('click')
+
+ expect(updateProfileMock).toHaveBeenCalledWith({ avatar_url: 'https://cdn.example.com/new.png' })
+ expect(authStoreState.user?.avatar_url).toBe('https://cdn.example.com/new.png')
+ expect(showSuccessMock).toHaveBeenCalledWith('Avatar updated')
+ })
+
+ it('rejects an oversized data URL before sending the request', async () => {
+ authStoreState.user = createUser()
+ const oversized = `data:image/png;base64,${Buffer.from(new Uint8Array(102401)).toString('base64')}`
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ await wrapper.get('[data-testid="profile-avatar-input"]').setValue(oversized)
+ await wrapper.get('[data-testid="profile-avatar-save"]').trigger('click')
+
+ expect(updateProfileMock).not.toHaveBeenCalled()
+ expect(showErrorMock).toHaveBeenCalledWith('Avatar image must be 100KB or smaller')
+ })
+
+ it('deletes the current avatar', async () => {
+ const updatedUser = createUser({ avatar_url: null })
+ updateProfileMock.mockResolvedValue(updatedUser)
+ authStoreState.user = createUser({ avatar_url: 'https://cdn.example.com/old.png' })
+
+ const wrapper = mount(ProfileInfoCard, {
+ props: {
+ user: authStoreState.user
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ ProfileIdentityBindingsSection: true
+ }
+ }
+ })
+
+ await wrapper.get('[data-testid="profile-avatar-delete"]').trigger('click')
+
+ expect(updateProfileMock).toHaveBeenCalledWith({ avatar_url: '' })
+ expect(authStoreState.user?.avatar_url).toBeNull()
+ expect(showSuccessMock).toHaveBeenCalledWith('Avatar removed')
+ })
+})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index e17ed616..62958642 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -941,6 +941,20 @@ export default {
unverified: 'Unverified',
verified: 'Verified',
},
+ avatar: {
+ title: 'Profile Avatar',
+ description: 'Set your avatar with a remote image URL or upload a small image.',
+ inputLabel: 'Avatar URL or data URL',
+ inputPlaceholder: 'https://cdn.example.com/avatar.png',
+ uploadAction: 'Upload image',
+ uploadHint: 'Images must be 100KB or smaller',
+ saveSuccess: 'Avatar updated',
+ deleteSuccess: 'Avatar removed',
+ invalidType: 'Please choose an image file',
+ fileTooLarge: 'Avatar image must be 100KB or smaller',
+ invalidValue: 'Enter a valid avatar URL or image data URL',
+ emptyDeleteHint: 'Avatar is already empty',
+ },
authBindings: {
title: 'Connected Sign-In Methods',
description: 'View current bindings and connect another provider to this account.',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index d54c0aba..7b7cdbb4 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -945,6 +945,20 @@ export default {
unverified: '未验证',
verified: '已验证',
},
+ avatar: {
+ title: '资料头像',
+ description: '支持填写远程图片 URL,或上传不超过 100KB 的头像图片。',
+ inputLabel: '头像 URL 或 data URL',
+ inputPlaceholder: 'https://cdn.example.com/avatar.png',
+ uploadAction: '上传图片',
+ uploadHint: '图片大小需不超过 100KB',
+ saveSuccess: '头像已更新',
+ deleteSuccess: '头像已删除',
+ invalidType: '请选择图片文件',
+ fileTooLarge: '头像图片必须不超过 100KB',
+ invalidValue: '请输入有效的头像 URL 或图片 data URL',
+ emptyDeleteHint: '当前没有可删除的头像',
+ },
authBindings: {
title: '登录方式绑定',
description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
--
GitLab
From 5d58c7c6fba7e7531ddfcf80b6a8b4915998b4b4 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:13:40 +0800
Subject: [PATCH 080/261] Add auth identity legacy backfill and email sync
---
...ntity_legacy_migration_integration_test.go | 206 ++++++++++++++++++
backend/internal/repository/user_repo.go | 102 +++++++++
...er_repo_email_identity_integration_test.go | 86 ++++++++
.../repository/user_repo_integration_test.go | 2 +
backend/internal/service/admin_service.go | 7 +
.../admin_service_email_identity_sync_test.go | 95 ++++++++
.../user_service_email_identity_sync_test.go | 68 ++++++
...auth_identity_legacy_external_backfill.sql | 187 ++++++++++++++++
8 files changed, 753 insertions(+)
create mode 100644 backend/internal/repository/auth_identity_legacy_migration_integration_test.go
create mode 100644 backend/internal/repository/user_repo_email_identity_integration_test.go
create mode 100644 backend/internal/service/admin_service_email_identity_sync_test.go
create mode 100644 backend/internal/service/user_service_email_identity_sync_test.go
create mode 100644 backend/migrations/115_auth_identity_legacy_external_backfill.sql
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
new file mode 100644
index 00000000..6a6312d4
--- /dev/null
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -0,0 +1,206 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+ TRUNCATE TABLE
+ auth_identity_channels,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_external_identities,
+ users
+ RESTART IDENTITY;
+`)
+ require.NoError(t, err)
+
+ var linuxDoUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoUserID))
+
+ var wechatUnionUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionUserID))
+
+ var wechatOpenIDOnlyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
+
+ var syntheticAuthIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
+RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
+
+ var linuxDoLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoUserID).Scan(&linuxDoLegacyID))
+
+ var wechatUnionLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
+RETURNING id
+`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
+
+ var wechatOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
+RETURNING id
+`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-user-1'
+`, linuxDoUserID).Scan(&linuxDoCount))
+ require.Equal(t, 1, linuxDoCount)
+
+ var wechatSubject string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT provider_subject
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-1'
+`, wechatUnionUserID).Scan(&wechatSubject))
+ require.Equal(t, "union-1", wechatSubject)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels channel
+JOIN auth_identities ai ON ai.id = channel.identity_id
+WHERE ai.user_id = $1
+ AND channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = 'oa'
+ AND channel.channel_app_id = 'wx-app-1'
+ AND channel.channel_subject = 'openid-union-1'
+`, wechatUnionUserID).Scan(&wechatChannelCount))
+ require.Equal(t, 1, wechatChannelCount)
+
+ var legacyOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
+ require.Equal(t, 1, legacyOpenIDOnlyReportCount)
+
+ var syntheticReviewCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
+ require.Equal(t, 1, syntheticReviewCount)
+
+ var unionLegacyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
+ require.Zero(t, unionLegacyReportCount)
+ require.NotZero(t, linuxDoLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index b5efd19d..b2190b68 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -76,6 +77,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
+ if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
+ return err
+ }
if tx != nil {
if err := tx.Commit(); err != nil {
@@ -150,6 +154,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
+ existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ oldEmail := existing.Email
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
@@ -185,6 +194,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
+ if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
+ return err
+ }
if tx != nil {
if err := tx.Commit(); err != nil {
@@ -196,6 +208,96 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
+func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
+ return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
+}
+
+func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
+ return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
+}
+
+func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil || userID <= 0 {
+ return nil
+ }
+
+ subject := normalizeEmailAuthIdentitySubject(email)
+ if subject == "" {
+ return nil
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": source}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ return nil
+}
+
+func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
+ newSubject := normalizeEmailAuthIdentitySubject(newEmail)
+ if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := clientFromContext(ctx, client).AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func normalizeEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return ""
+ }
+ if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
+ return ""
+ }
+ return normalized
+}
+
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go
new file mode 100644
index 00000000..fddd82c5
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_identity_integration_test.go
@@ -0,0 +1,86 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
+ user := &service.User{
+ Email: "repo-create@example.com",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ identity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("repo-create@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, identity.UserID)
+}
+
+func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
+ user := &service.User{
+ Email: "linuxdo-legacy-user@linuxdo-connect.invalid",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ count, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(count)
+}
+
+func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
+ user := s.mustCreateUser(&service.User{
+ Email: "before-update@example.com",
+ })
+
+ user.Email = "after-update@example.com"
+ s.Require().NoError(s.repo.Update(s.ctx, user))
+
+ newIdentity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("after-update@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, newIdentity.UserID)
+
+ oldCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("before-update@example.com"),
+ ).
+ Count(context.Background())
+ s.Require().NoError(err)
+ s.Require().Zero(oldCount)
+}
diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go
index f5d0f9ff..07fb0598 100644
--- a/backend/internal/repository/user_repo_integration_test.go
+++ b/backend/internal/repository/user_repo_integration_test.go
@@ -26,6 +26,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 3490374e..79840e5b 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -630,6 +630,9 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
+ if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
+ return nil, fmt.Errorf("sync email auth identity: %w", err)
+ }
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
@@ -665,6 +668,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
+ oldEmail := user.Email
if input.Email != "" {
user.Email = input.Email
@@ -697,6 +701,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
+ if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
+ return nil, fmt.Errorf("sync email auth identity: %w", err)
+ }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
new file mode 100644
index 00000000..3f3d867c
--- /dev/null
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -0,0 +1,95 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type ensureEmailCall struct {
+ userID int64
+ email string
+}
+
+type replaceEmailCall struct {
+ userID int64
+ oldEmail string
+ newEmail string
+}
+
+type emailSyncUserRepoStub struct {
+ *userRepoStub
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+}
+
+func (s *emailSyncUserRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return nil
+}
+
+func (s *emailSyncUserRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return nil
+}
+
+func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
+ repo := &emailSyncUserRepoStub{userRepoStub: &userRepoStub{nextID: 55}}
+ svc := &adminServiceImpl{userRepo: repo}
+
+ user, err := svc.CreateUser(context.Background(), &CreateUserInput{
+ Email: "admin-created@example.com",
+ Password: "strong-pass",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, []ensureEmailCall{{
+ userID: 55,
+ email: "admin-created@example.com",
+ }}, repo.ensureCalls)
+ require.Empty(t, repo.replaceCalls)
+}
+
+func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
+ repo := &emailSyncUserRepoStub{
+ userRepoStub: &userRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
+ },
+ },
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{
+ Email: "after@example.com",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, "after@example.com", updated.Email)
+ require.Equal(t, []replaceEmailCall{{
+ userID: 91,
+ oldEmail: "before@example.com",
+ newEmail: "after@example.com",
+ }}, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
new file mode 100644
index 00000000..3950df8b
--- /dev/null
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -0,0 +1,68 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type emailSyncMockUserRepo struct {
+ *mockUserRepo
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+}
+
+func (m *emailSyncMockUserRepo) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ m.ensureCalls = append(m.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return nil
+}
+
+func (m *emailSyncMockUserRepo) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ m.replaceCalls = append(m.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return nil
+}
+
+func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
+ repo := &emailSyncMockUserRepo{
+ mockUserRepo: &mockUserRepo{
+ getByIDUser: &User{
+ ID: 19,
+ Email: "profile-before@example.com",
+ Username: "tester",
+ Concurrency: 2,
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ newEmail := "profile-after@example.com"
+ updated, err := svc.UpdateProfile(context.Background(), 19, UpdateProfileRequest{
+ Email: &newEmail,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, newEmail, updated.Email)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, []replaceEmailCall{{
+ userID: 19,
+ oldEmail: "profile-before@example.com",
+ newEmail: "profile-after@example.com",
+ }}, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
new file mode 100644
index 00000000..f4a13c36
--- /dev/null
+++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
@@ -0,0 +1,187 @@
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'linuxdo',
+ 'linuxdo',
+ legacy.provider_user_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'wechat',
+ 'wechat-main',
+ legacy.provider_union_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', legacy.provider_union_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+) AS legacy
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
+ meta.metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ CROSS JOIN LATERAL (
+ SELECT COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
+ ) AS meta
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+) AS legacy
+JOIN auth_identities AS ai
+ ON ai.user_id = legacy.user_id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat-main'
+ AND ai.provider_subject = legacy.provider_union_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND legacy.provider_user_id <> ''
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'synthetic_auth_identity:' || ai.id::text,
+ COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'auth_identity_id', ai.id,
+ 'user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'reason', 'synthetic wechat auth identity still lacks unionid metadata and needs remediation',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM auth_identities AS ai
+WHERE ai.provider_type = 'wechat'
+ AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'
+ AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
+ON CONFLICT (report_type, report_key) DO NOTHING;
--
GitLab
From 16be82b9599a47760e9e07b75675f7a0c6db11dd Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:14:05 +0800
Subject: [PATCH 081/261] fix payment visible methods and resume recovery
---
backend/internal/payment/provider/wxpay.go | 6 +-
.../internal/payment/provider/wxpay_test.go | 31 ++++++++
.../internal/service/payment_config_limits.go | 46 ++++++++++++
.../service/payment_config_limits_test.go | 71 +++++++++++++++++++
frontend/src/views/user/PaymentResultView.vue | 38 ++++++----
.../user/__tests__/PaymentResultView.spec.ts | 24 +++++++
6 files changed, 199 insertions(+), 17 deletions(-)
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 84b324a8..47569ee3 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -310,12 +310,14 @@ func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (st
return "", fmt.Errorf("return URL must be an absolute http(s) URL")
}
- values := url.Values{}
+ values := u.Query()
values.Set("out_trade_no", strings.TrimSpace(req.OrderID))
if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" {
values.Set("payment_type", paymentType)
}
- u.Path = wxpayResultPath
+ if strings.TrimSpace(u.Path) == "" {
+ u.Path = wxpayResultPath
+ }
u.RawPath = ""
u.RawQuery = values.Encode()
u.Fragment = ""
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index 5074c545..8489e261 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -4,6 +4,7 @@ package provider
import (
"context"
+ "net/url"
"strings"
"testing"
@@ -263,6 +264,36 @@ func TestNewWxpay(t *testing.T) {
}
}
+func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) {
+ t.Parallel()
+
+ resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{
+ OrderID: "sub2_42",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("buildWxpayResultURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(resultURL)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ query := parsed.Query()
+ if parsed.Path != wxpayResultPath {
+ t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath)
+ }
+ if query.Get("resume_token") != "resume-42" {
+ t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42")
+ }
+ if query.Get("order_id") != "42" {
+ t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42")
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42")
+ }
+}
+
func TestResolveWxpayJSAPIAppID(t *testing.T) {
t.Parallel()
diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go
index 56905278..f30b119a 100644
--- a/backend/internal/service/payment_config_limits.go
+++ b/backend/internal/service/payment_config_limits.go
@@ -20,6 +20,18 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
+ if s.settingRepo != nil {
+ vals, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingPaymentVisibleMethodAlipayEnabled,
+ SettingPaymentVisibleMethodAlipaySource,
+ SettingPaymentVisibleMethodWxpayEnabled,
+ SettingPaymentVisibleMethodWxpaySource,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("query visible method settings: %w", err)
+ }
+ typeInstances = pcApplyVisibleMethodRouting(typeInstances, vals, buildVisibleMethodSourceAvailability(instances))
+ }
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
@@ -31,6 +43,40 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
+func pcApplyVisibleMethodRouting(typeInstances map[string][]*dbent.PaymentProviderInstance, vals map[string]string, available map[string]bool) map[string][]*dbent.PaymentProviderInstance {
+ if len(typeInstances) == 0 {
+ return typeInstances
+ }
+
+ filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances))
+ for paymentType, instances := range typeInstances {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ switch visibleMethod {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ if !visibleMethodShouldBeExposed(visibleMethod, vals, available) {
+ continue
+ }
+ targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[visibleMethodSourceSettingKey(visibleMethod)])
+ if !ok {
+ continue
+ }
+ matching := make([]*dbent.PaymentProviderInstance, 0, len(instances))
+ for _, inst := range instances {
+ if inst.ProviderKey == targetProviderKey {
+ matching = append(matching, inst)
+ }
+ }
+ if len(matching) == 0 {
+ continue
+ }
+ filtered[paymentType] = matching
+ default:
+ filtered[paymentType] = instances
+ }
+ }
+ return filtered
+}
+
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go
index 73ad66ef..4a9d663d 100644
--- a/backend/internal/service/payment_config_limits_test.go
+++ b/backend/internal/service/payment_config_limits_test.go
@@ -1,6 +1,7 @@
package service
import (
+ "context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -299,3 +300,73 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}
})
}
+
+func TestGetAvailableMethodLimitsRespectsVisibleMethodRouting(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ },
+ },
+ }
+
+ resp, err := svc.GetAvailableMethodLimits(ctx)
+ if err != nil {
+ t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
+ }
+
+ alipayLimits, ok := resp.Methods[payment.TypeAlipay]
+ if !ok {
+ t.Fatalf("expected visible alipay limits, got %v", resp.Methods)
+ }
+ if alipayLimits.SingleMin != 20 || alipayLimits.SingleMax != 200 {
+ t.Fatalf("alipay limits = %+v, want easypay-only min=20 max=200", alipayLimits)
+ }
+ if _, ok := resp.Methods[payment.TypeWxpay]; ok {
+ t.Fatalf("wxpay should be hidden when visible method is disabled, got %v", resp.Methods[payment.TypeWxpay])
+ }
+ if resp.GlobalMin != 20 || resp.GlobalMax != 200 {
+ t.Fatalf("global range = (%v, %v), want (20, 200)", resp.GlobalMin, resp.GlobalMax)
+ }
+}
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 53bbb550..e1bcbbe5 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -142,10 +142,11 @@ onMounted(async () => {
const resumeToken = typeof route.query.resume_token === 'string'
? route.query.resume_token
: ''
- let orderId = Number(route.query.order_id) || 0
+ const routeOrderId = Number(route.query.order_id) || 0
const outTradeNo = String(route.query.out_trade_no || '')
+ let orderId = 0
- if (!orderId && resumeToken && typeof window !== 'undefined') {
+ if (resumeToken && typeof window !== 'undefined') {
const restored = readPaymentRecoverySnapshot(
window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY),
{ resumeToken },
@@ -155,17 +156,31 @@ onMounted(async () => {
}
}
- if (!order.value && !orderId && resumeToken) {
+ if (!order.value && resumeToken && orderId) {
+ try {
+ order.value = await paymentStore.pollOrderStatus(orderId)
+ } catch (_err: unknown) {
+ // Fall through to signed resume-token recovery below.
+ }
+ }
+
+ if (!order.value && resumeToken) {
try {
const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken)
order.value = result.data
- orderId = result.data.id
+ if (!orderId) {
+ orderId = result.data.id
+ }
} catch (_err: unknown) {
- // Resume token recovery failed, continue to legacy fallback paths.
+ // Resume token recovery failed; do not trust legacy public out_trade_no fallback.
}
}
- if (!order.value && orderId) {
+ if (!resumeToken) {
+ orderId = routeOrderId
+ }
+
+ if (!order.value && !resumeToken && orderId) {
try {
order.value = await paymentStore.pollOrderStatus(orderId)
} catch (_err: unknown) {
@@ -173,7 +188,8 @@ onMounted(async () => {
}
}
- if (!order.value && outTradeNo) {
+ const hasLegacyFallbackContext = Boolean(route.query.trade_status || route.query.money || route.query.type)
+ if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) {
returnInfo.value = {
outTradeNo,
money: String(route.query.money || ''),
@@ -191,14 +207,6 @@ onMounted(async () => {
} catch (_e: unknown) { /* fall through */ }
}
}
-
- if (!order.value && orderId) {
- try {
- order.value = await paymentStore.pollOrderStatus(orderId)
- } catch (_err: unknown) {
- // Order lookup failed, will show returnInfo fallback.
- }
- }
loading.value = false
})
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index b1caa526..bfc044a7 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -76,6 +76,7 @@ describe('PaymentResultView', () => {
it('restores order id from a matching resume token and does not trust query success flags', async () => {
routeState.query = {
resume_token: 'resume-42',
+ order_id: '999',
status: 'success',
}
window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({
@@ -110,6 +111,29 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).not.toContain('payment.result.success')
})
+ it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => {
+ routeState.query = {
+ resume_token: 'resume-fail',
+ out_trade_no: 'legacy-should-not-run',
+ trade_status: 'TRADE_SUCCESS',
+ }
+ resolveOrderPublicByResumeToken.mockRejectedValueOnce(new Error('resume failed'))
+
+ mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail')
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ expect(verifyOrder).not.toHaveBeenCalled()
+ })
+
it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => {
routeState.query = {
out_trade_no: 'legacy-123',
--
GitLab
From b79052aaf2c5dfe201cc00c70cd37809961a5f2d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:16:06 +0800
Subject: [PATCH 082/261] Decouple email sync tests from local stubs
---
.../admin_service_email_identity_sync_test.go | 130 +++++++++++++++---
.../user_service_email_identity_sync_test.go | 43 +-----
2 files changed, 114 insertions(+), 59 deletions(-)
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
index 3f3d867c..d555d609 100644
--- a/backend/internal/service/admin_service_email_identity_sync_test.go
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -4,9 +4,11 @@ package service
import (
"context"
+ "fmt"
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -21,36 +23,122 @@ type replaceEmailCall struct {
newEmail string
}
-type emailSyncUserRepoStub struct {
- *userRepoStub
+type emailSyncRepoStub struct {
+ user *User
+ nextID int64
+ updateCalls int
+ created []*User
+ updated []*User
ensureCalls []ensureEmailCall
replaceCalls []replaceEmailCall
}
-func (s *emailSyncUserRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
- s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
+ if s.nextID != 0 && user.ID == 0 {
+ user.ID = s.nextID
+ }
+ s.created = append(s.created, user)
+ s.user = user
return nil
}
-func (s *emailSyncUserRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
- s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
- userID: userID,
- oldEmail: oldEmail,
- newEmail: newEmail,
- })
+func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ if s.user == nil {
+ return nil, ErrUserNotFound
+ }
+ cloned := *s.user
+ return &cloned, nil
+}
+
+func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) {
+ return nil, ErrUserNotFound
+}
+
+func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) {
+ return nil, fmt.Errorf("unexpected GetFirstAdmin call")
+}
+
+func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error {
+ s.updateCalls++
+ s.updated = append(s.updated, user)
+ s.user = user
return nil
}
-func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected GetUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ return fmt.Errorf("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected List call")
+}
+
+func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected ListWithFilters call")
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
-func (s *emailSyncUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
+func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+
+func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return nil
+}
+
+func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return nil
+}
+
func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
- repo := &emailSyncUserRepoStub{userRepoStub: &userRepoStub{nextID: 55}}
+ repo := &emailSyncRepoStub{nextID: 55}
svc := &adminServiceImpl{userRepo: repo}
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
@@ -67,15 +155,13 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
}
func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
- repo := &emailSyncUserRepoStub{
- userRepoStub: &userRepoStub{
- user: &User{
- ID: 91,
- Email: "before@example.com",
- Role: RoleUser,
- Status: StatusActive,
- Concurrency: 3,
- },
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
},
}
svc := &adminServiceImpl{userRepo: repo}
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
index 3950df8b..8109b368 100644
--- a/backend/internal/service/user_service_email_identity_sync_test.go
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -5,48 +5,17 @@ package service
import (
"context"
"testing"
- "time"
"github.com/stretchr/testify/require"
)
-type emailSyncMockUserRepo struct {
- *mockUserRepo
- ensureCalls []ensureEmailCall
- replaceCalls []replaceEmailCall
-}
-
-func (m *emailSyncMockUserRepo) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
- m.ensureCalls = append(m.ensureCalls, ensureEmailCall{userID: userID, email: email})
- return nil
-}
-
-func (m *emailSyncMockUserRepo) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
- m.replaceCalls = append(m.replaceCalls, replaceEmailCall{
- userID: userID,
- oldEmail: oldEmail,
- newEmail: newEmail,
- })
- return nil
-}
-
-func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
- return map[int64]*time.Time{}, nil
-}
-
-func (m *emailSyncMockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
- return nil, nil
-}
-
func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
- repo := &emailSyncMockUserRepo{
- mockUserRepo: &mockUserRepo{
- getByIDUser: &User{
- ID: 19,
- Email: "profile-before@example.com",
- Username: "tester",
- Concurrency: 2,
- },
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 19,
+ Email: "profile-before@example.com",
+ Username: "tester",
+ Concurrency: 2,
},
}
svc := NewUserService(repo, nil, nil, nil)
--
GitLab
From beeab54ae3b1d44f291822f99d064d50bef59cd7 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:17:48 +0800
Subject: [PATCH 083/261] Implement latest-used user repo queries
---
backend/internal/repository/user_repo.go | 46 ++++++++++++++++++++++++
1 file changed, 46 insertions(+)
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index b2190b68..7378611d 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
@@ -298,6 +299,51 @@ func normalizeEmailAuthIdentitySubject(email string) string {
return normalized
}
+func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ result := make(map[int64]*time.Time, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+ if r.sql == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ rows, err := r.sql.QueryContext(ctx, `
+ SELECT user_id, MAX(created_at) AS last_used_at
+ FROM usage_logs
+ WHERE user_id = ANY($1)
+ GROUP BY user_id
+ `, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var (
+ userID int64
+ lastUsedAt time.Time
+ )
+ if err := rows.Scan(&userID, &lastUsedAt); err != nil {
+ return nil, err
+ }
+ ts := lastUsedAt.UTC()
+ result[userID] = &ts
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
--
GitLab
From 7da512406790011d16aff30fd9df5d574deab558 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 00:21:29 +0800
Subject: [PATCH 084/261] feat(channel-monitor): add feature switch settings +
fix extra_models save
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Settings:
- New "功能开关" tab between 通用设置 and 安全与认证
- ChannelMonitorEnabled toggle: runner skips scheduling when false,
user-facing list returns empty
- ChannelMonitorDefaultIntervalSeconds (15-3600): pre-fills interval
when creating a new monitor; each monitor can still override
Bug fix:
- ModelTagInput now commits pending input on blur, not just Enter/Tab.
Previously, clicking "save" while an extra model was typed but not yet
confirmed with Enter silently dropped the value (the DB stored
extra_models=[] even though the user had entered values).
Backend:
- domain_constants: SettingKeyChannelMonitor{Enabled,DefaultIntervalSeconds}
- SettingService.GetChannelMonitorRuntime: lightweight getter used by
runner tick + user handler per-request (fail-open on DB error)
- Runner tickDueChecks: bails early when feature disabled
- ChannelMonitorUserHandler: checks feature flag before serving
- Runner doc comment: scheduler state is implicit (every tick re-reads
ListEnabled from the DB), so create/update/delete operations on monitors
keep the schedule current without an explicit reload
Bump VERSION to 0.1.114.25
---
backend/cmd/server/wire_gen.go | 4 +-
.../internal/handler/admin/setting_handler.go | 28 +++++++
.../handler/channel_monitor_user_handler.go | 29 ++++++-
backend/internal/handler/dto/settings.go | 7 ++
backend/internal/handler/setting_handler.go | 3 +
.../service/channel_monitor_runner.go | 22 ++++--
backend/internal/service/domain_constants.go | 12 +++
backend/internal/service/setting_service.go | 76 +++++++++++++++++++
backend/internal/service/settings_view.go | 8 ++
backend/internal/service/wire.go | 5 +-
frontend/src/api/admin/settings.ts | 8 ++
.../admin/channel/ModelTagInput.vue | 1 +
.../admin/monitor/MonitorFormDialog.vue | 13 +++-
frontend/src/i18n/locales/en.ts | 11 +++
frontend/src/i18n/locales/zh.ts | 11 +++
frontend/src/stores/app.ts | 2 +
frontend/src/types/index.ts | 2 +
frontend/src/views/admin/SettingsView.vue | 55 ++++++++++++++
18 files changed, 283 insertions(+), 14 deletions(-)
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 8e367e81..754f814a 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -217,8 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB)
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
- channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService)
- channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService)
+ channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
+ channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
_ = channelMonitorRunner
registry := payment.ProvideRegistry()
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index a882d1a1..40c944eb 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -235,6 +235,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
@@ -425,6 +428,10 @@ type UpdateSettingsRequest struct {
PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
+
+ // Channel Monitor feature switch
+ ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
}
// UpdateSettings 更新系统设置
@@ -1219,6 +1226,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AccountQuotaNotifyEmails
}(),
+ ChannelMonitorEnabled: func() bool {
+ if req.ChannelMonitorEnabled != nil {
+ return *req.ChannelMonitorEnabled
+ }
+ return previousSettings.ChannelMonitorEnabled
+ }(),
+ ChannelMonitorDefaultIntervalSeconds: func() int {
+ if req.ChannelMonitorDefaultIntervalSeconds != nil {
+ return *req.ChannelMonitorDefaultIntervalSeconds
+ }
+ return previousSettings.ChannelMonitorDefaultIntervalSeconds
+ }(),
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
@@ -1449,6 +1468,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
+
+ ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
@@ -1805,6 +1827,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
+ if before.ChannelMonitorEnabled != after.ChannelMonitorEnabled {
+ changed = append(changed, "channel_monitor_enabled")
+ }
+ if before.ChannelMonitorDefaultIntervalSeconds != after.ChannelMonitorDefaultIntervalSeconds {
+ changed = append(changed, "channel_monitor_default_interval_seconds")
+ }
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
diff --git a/backend/internal/handler/channel_monitor_user_handler.go b/backend/internal/handler/channel_monitor_user_handler.go
index 6a513dc1..cc36b334 100644
--- a/backend/internal/handler/channel_monitor_user_handler.go
+++ b/backend/internal/handler/channel_monitor_user_handler.go
@@ -14,11 +14,28 @@ import (
// ChannelMonitorUserHandler 渠道监控用户只读 handler。
type ChannelMonitorUserHandler struct {
monitorService *service.ChannelMonitorService
+ settingService *service.SettingService
}
// NewChannelMonitorUserHandler 创建 handler。
-func NewChannelMonitorUserHandler(monitorService *service.ChannelMonitorService) *ChannelMonitorUserHandler {
- return &ChannelMonitorUserHandler{monitorService: monitorService}
+// settingService 用于每次请求前读取功能开关;关闭时 List/GetStatus 直接返回空/404。
+func NewChannelMonitorUserHandler(
+ monitorService *service.ChannelMonitorService,
+ settingService *service.SettingService,
+) *ChannelMonitorUserHandler {
+ return &ChannelMonitorUserHandler{
+ monitorService: monitorService,
+ settingService: settingService,
+ }
+}
+
+// featureEnabled 返回当前渠道监控功能是否开启。
+// settingService 为 nil(测试场景)视为启用。
+func (h *ChannelMonitorUserHandler) featureEnabled(c *gin.Context) bool {
+ if h.settingService == nil {
+ return true
+ }
+ return h.settingService.GetChannelMonitorRuntime(c.Request.Context()).Enabled
}
// --- Response ---
@@ -123,6 +140,10 @@ func userMonitorDetailToResponse(d *service.UserMonitorDetail) *channelMonitorUs
// List GET /api/v1/channel-monitors
func (h *ChannelMonitorUserHandler) List(c *gin.Context) {
+ if !h.featureEnabled(c) {
+ response.Success(c, gin.H{"items": []channelMonitorUserListItem{}})
+ return
+ }
views, err := h.monitorService.ListUserView(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
@@ -137,6 +158,10 @@ func (h *ChannelMonitorUserHandler) List(c *gin.Context) {
// GetStatus GET /api/v1/channel-monitors/:id/status
func (h *ChannelMonitorUserHandler) GetStatus(c *gin.Context) {
+ if !h.featureEnabled(c) {
+ response.ErrorFrom(c, service.ErrChannelMonitorNotFound)
+ return
+ }
// 复用 admin.ParseChannelMonitorID 保持错误码与日志一致。
id, ok := admin.ParseChannelMonitorID(c)
if !ok {
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index fc6a3f9e..9d9bb6c5 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -183,6 +183,10 @@ type SystemSettings struct {
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
+
+ // Channel Monitor feature switch
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
}
type DefaultSubscriptionSetting struct {
@@ -230,6 +234,9 @@ type PublicSettings struct {
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
+
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index c0f5c28b..8d72206f 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -70,5 +70,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
})
}
diff --git a/backend/internal/service/channel_monitor_runner.go b/backend/internal/service/channel_monitor_runner.go
index 377903d3..4655e6df 100644
--- a/backend/internal/service/channel_monitor_runner.go
+++ b/backend/internal/service/channel_monitor_runner.go
@@ -18,8 +18,13 @@ import (
// - Stop 时优雅关闭:池 drain + ticker.Stop + wg.Wait
//
// 不引入 cron 库;清理调度通过"每小时检查时间"实现,足够 MVP。
+//
+// 定时任务维护:删除/创建/编辑 monitor 无需显式 reload,每个 tick 都会重新查 DB
+// (ListEnabled + listDueForCheck),新 monitor 的 LastCheckedAt 为 nil 天然立即到期,
+// 被删除的 monitor 自然不再返回,interval 变化下次 tick 自动按新值判定。
type ChannelMonitorRunner struct {
- svc *ChannelMonitorService
+ svc *ChannelMonitorService
+ settingService *SettingService
pool pond.Pool
stopCh chan struct{}
@@ -37,11 +42,13 @@ type ChannelMonitorRunner struct {
}
// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用。
-func NewChannelMonitorRunner(svc *ChannelMonitorService) *ChannelMonitorRunner {
+// settingService 用于在每次 tick 前读取功能开关;传 nil 时视为总是启用(兼容测试)。
+func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
return &ChannelMonitorRunner{
- svc: svc,
- stopCh: make(chan struct{}),
- inFlight: make(map[int64]struct{}),
+ svc: svc,
+ settingService: settingService,
+ stopCh: make(chan struct{}),
+ inFlight: make(map[int64]struct{}),
}
}
@@ -93,10 +100,15 @@ func (r *ChannelMonitorRunner) dueCheckLoop() {
// tickDueChecks 一次扫描:查询到期监控并逐个提交到池。
// 已在执行的 monitor 会被跳过(防止单次检测耗时 > interval 时重复调度)。
// 池满时使用 TrySubmit 跳过(不能阻塞 ticker),同时立即释放已占用的 inFlight 槽。
+// 当功能开关关闭时直接返回——管理员可以动态禁用模块,runner 不会拉取 DB。
func (r *ChannelMonitorRunner) tickDueChecks() {
ctx, cancel := context.WithTimeout(context.Background(), monitorListDueTimeout)
defer cancel()
+ if r.settingService != nil && !r.settingService.GetChannelMonitorRuntime(ctx).Enabled {
+ return
+ }
+
due, err := r.svc.listDueForCheck(ctx)
if err != nil {
slog.Warn("channel_monitor: list due failed", "error", err)
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 3c6888b8..ef2259ed 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -242,6 +242,18 @@ const (
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
+ // =========================
+ // Channel Monitor (渠道监控)
+ // =========================
+
+ // SettingKeyChannelMonitorEnabled is a DB-backed soft switch for the channel monitor feature.
+ // When false: runner skips scheduling and user-facing endpoints return an empty list.
+ SettingKeyChannelMonitorEnabled = "channel_monitor_enabled"
+
+ // SettingKeyChannelMonitorDefaultIntervalSeconds controls the default interval (seconds)
+ // pre-filled when creating a new channel monitor from the admin UI. Range: [15, 3600].
+ SettingKeyChannelMonitorDefaultIntervalSeconds = "channel_monitor_default_interval_seconds"
+
// =========================
// Overload Cooldown (529)
// =========================
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index f2b644be..c901be84 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -450,6 +450,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
SettingKeyAccountQuotaNotifyEnabled,
+ SettingKeyChannelMonitorEnabled,
+ SettingKeyChannelMonitorDefaultIntervalSeconds,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -532,9 +534,67 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL],
+
+ ChannelMonitorEnabled: !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]),
+ ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),
}, nil
}
+// channelMonitorIntervalMin / channelMonitorIntervalMax bound the default interval
+// (mirrors the monitor-level constraint but lives here so setting_service stays decoupled).
+const (
+ channelMonitorIntervalMin = 15
+ channelMonitorIntervalMax = 3600
+ channelMonitorIntervalFallback = 60
+)
+
+// parseChannelMonitorInterval parses the stored string and clamps to [15, 3600].
+// Empty / invalid input falls back to channelMonitorIntervalFallback.
+func parseChannelMonitorInterval(raw string) int {
+ v, err := strconv.Atoi(strings.TrimSpace(raw))
+ if err != nil {
+ return channelMonitorIntervalFallback
+ }
+ return clampChannelMonitorInterval(v)
+}
+
+// clampChannelMonitorInterval clamps v to the allowed range. 0 means "not provided".
+func clampChannelMonitorInterval(v int) int {
+ if v <= 0 {
+ return 0
+ }
+ if v < channelMonitorIntervalMin {
+ return channelMonitorIntervalMin
+ }
+ if v > channelMonitorIntervalMax {
+ return channelMonitorIntervalMax
+ }
+ return v
+}
+
+// ChannelMonitorRuntime is the lightweight view of the channel monitor feature
+// consumed by the runner and user-facing handlers.
+type ChannelMonitorRuntime struct {
+ Enabled bool
+ DefaultIntervalSeconds int
+}
+
+// GetChannelMonitorRuntime reads the channel monitor feature flags directly from
+// the settings store. Fail-open: on error returns Enabled=true with the default interval.
+func (s *SettingService) GetChannelMonitorRuntime(ctx context.Context) ChannelMonitorRuntime {
+ vals, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingKeyChannelMonitorEnabled,
+ SettingKeyChannelMonitorDefaultIntervalSeconds,
+ })
+ if err != nil {
+ return ChannelMonitorRuntime{Enabled: true, DefaultIntervalSeconds: channelMonitorIntervalFallback}
+ }
+ return ChannelMonitorRuntime{
+ Enabled: !isFalseSettingValue(vals[SettingKeyChannelMonitorEnabled]),
+ DefaultIntervalSeconds: parseChannelMonitorInterval(vals[SettingKeyChannelMonitorDefaultIntervalSeconds]),
+ }
+}
+
// SetOnUpdateCallback sets a callback function to be called when settings are updated
// This is used for cache invalidation (e.g., HTML cache in frontend server)
func (s *SettingService) SetOnUpdateCallback(callback func()) {
@@ -1085,6 +1145,12 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
}
+ // Channel monitor feature switch
+ updates[SettingKeyChannelMonitorEnabled] = strconv.FormatBool(settings.ChannelMonitorEnabled)
+ if v := clampChannelMonitorInterval(settings.ChannelMonitorDefaultIntervalSeconds); v > 0 {
+ updates[SettingKeyChannelMonitorDefaultIntervalSeconds] = strconv.Itoa(v)
+ }
+
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@@ -1630,6 +1696,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOpsQueryModeDefault: "auto",
SettingKeyOpsMetricsIntervalSeconds: "60",
+ // Channel monitor defaults (enabled, 60s)
+ SettingKeyChannelMonitorEnabled: "true",
+ SettingKeyChannelMonitorDefaultIntervalSeconds: "60",
+
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
@@ -1932,6 +2002,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
}
+ // Channel monitor feature (default: enabled, 60s)
+ result.ChannelMonitorEnabled = !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled])
+ result.ChannelMonitorDefaultIntervalSeconds = parseChannelMonitorInterval(
+ settings[SettingKeyChannelMonitorDefaultIntervalSeconds],
+ )
+
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index d2ef8fae..972faf80 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -125,6 +125,10 @@ type SystemSettings struct {
OpsQueryModeDefault string
OpsMetricsIntervalSeconds int
+ // Channel Monitor feature
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+
// Claude Code version check
MinClaudeCodeVersion string
MaxClaudeCodeVersion string
@@ -209,6 +213,10 @@ type PublicSettings struct {
AccountQuotaNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
+
+ // Channel Monitor feature
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
}
type WeChatConnectOAuthConfig struct {
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index ce933798..5d8d88d2 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -500,8 +500,9 @@ func ProvideChannelMonitorService(
// ProvideChannelMonitorRunner 创建并启动渠道监控调度器。
// Runner.Stop 由 cleanup function 调用。
-func ProvideChannelMonitorRunner(svc *ChannelMonitorService) *ChannelMonitorRunner {
- r := NewChannelMonitorRunner(svc)
+// settingService 用于 runner 每个 tick 读取功能开关。
+func ProvideChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
+ r := NewChannelMonitorRunner(svc, settingService)
r.Start()
return r
}
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 0403b0f3..ab85c30c 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -469,6 +469,10 @@ export interface SystemSettings {
balance_low_notify_recharge_url: string;
account_quota_notify_enabled: boolean;
account_quota_notify_emails: NotifyEmailEntry[];
+
+ // Channel Monitor feature switch
+ channel_monitor_enabled: boolean;
+ channel_monitor_default_interval_seconds: number;
}
export interface UpdateSettingsRequest {
@@ -618,6 +622,10 @@ export interface UpdateSettingsRequest {
balance_low_notify_recharge_url?: string;
account_quota_notify_enabled?: boolean;
account_quota_notify_emails?: NotifyEmailEntry[];
+
+ // Channel Monitor feature switch
+ channel_monitor_enabled?: boolean;
+ channel_monitor_default_interval_seconds?: number;
}
/**
diff --git a/frontend/src/components/admin/channel/ModelTagInput.vue b/frontend/src/components/admin/channel/ModelTagInput.vue
index a1ce4022..b91aa119 100644
--- a/frontend/src/components/admin/channel/ModelTagInput.vue
+++ b/frontend/src/components/admin/channel/ModelTagInput.vue
@@ -27,6 +27,7 @@
@keydown.tab.prevent="addModel"
@keydown.delete="handleBackspace"
@paste="handlePaste"
+ @blur="addModel"
/>
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
index 920c3f79..e1489ffb 100644
--- a/frontend/src/components/admin/monitor/MonitorFormDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -143,6 +143,13 @@ const emit = defineEmits<{
const { t } = useI18n()
const appStore = useAppStore()
+// System-configured default interval for new monitors. Falls back to the static
+// constant when public settings haven't loaded yet or store the legacy 0 value.
+const systemDefaultInterval = computed(() => {
+ const configured = appStore.cachedPublicSettings?.channel_monitor_default_interval_seconds
+ return configured && configured > 0 ? configured : DEFAULT_INTERVAL_SECONDS
+})
+
// editing is true when we have an existing monitor
const editing = computed(() => props.monitor)
@@ -173,7 +180,7 @@ const form = reactive({
primary_model: '',
extra_models: [],
group_name: '',
- interval_seconds: DEFAULT_INTERVAL_SECONDS,
+ interval_seconds: systemDefaultInterval.value,
enabled: true,
})
@@ -191,7 +198,7 @@ function resetForm() {
form.primary_model = ''
form.extra_models = []
form.group_name = ''
- form.interval_seconds = DEFAULT_INTERVAL_SECONDS
+ form.interval_seconds = systemDefaultInterval.value
form.enabled = true
}
@@ -203,7 +210,7 @@ function loadFromMonitor(m: ChannelMonitor) {
form.primary_model = m.primary_model
form.extra_models = [...(m.extra_models || [])]
form.group_name = m.group_name || ''
- form.interval_seconds = m.interval_seconds || DEFAULT_INTERVAL_SECONDS
+ form.interval_seconds = m.interval_seconds || systemDefaultInterval.value
form.enabled = m.enabled
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index b95c8b44..5e1ee189 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -4530,6 +4530,7 @@ export default {
description: 'Manage registration, email verification, default values, and SMTP settings',
tabs: {
general: 'General',
+ features: 'Feature Switches',
security: 'Security',
users: 'Users',
gateway: 'Gateway',
@@ -4537,6 +4538,16 @@ export default {
backup: 'Backup',
payment: 'Payment',
},
+ features: {
+ channelMonitor: {
+ title: 'Channel Monitor',
+ description: 'Periodically probe configured channels and surface availability / latency to users. Turning it off stops the scheduler and returns an empty list on the user page.',
+ enabled: 'Enable Channel Monitor',
+ enabledHint: 'Disabling stops background checks; existing history is preserved.',
+ defaultInterval: 'Default check interval (seconds)',
+ defaultIntervalHint: 'Pre-fills the interval when creating a new monitor; each monitor can override it. Range 15 – 3600.',
+ },
+ },
emailTabDisabledTitle: 'Email Verification Not Enabled',
emailTabDisabledHint: 'Enable email verification in the Security tab to configure SMTP settings.',
registration: {
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 54bc03c5..df569d54 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -4695,6 +4695,7 @@ export default {
description: '管理注册、邮箱验证、默认值和 SMTP 设置',
tabs: {
general: '通用设置',
+ features: '功能开关',
security: '安全与认证',
users: '用户默认值',
gateway: '网关服务',
@@ -4702,6 +4703,16 @@ export default {
backup: '数据备份',
payment: '支付设置',
},
+ features: {
+ channelMonitor: {
+ title: '渠道监控',
+ description: '定期对配置的渠道发起健康检查,向用户展示可用性与延迟。关闭后调度器停止扫描,用户端列表为空。',
+ enabled: '启用渠道监控',
+ enabledHint: '关闭后后台不再执行定时检测,已有数据保留。',
+ defaultInterval: '默认检测间隔(秒)',
+ defaultIntervalHint: '新建渠道监控时表单的默认值,可被单个渠道覆盖。范围 15 – 3600 秒。',
+ },
+ },
emailTabDisabledTitle: '邮箱验证未启用',
emailTabDisabledHint: '请在「安全与认证」选项卡中启用邮箱验证后,再配置 SMTP 设置。',
registration: {
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index f5c39e62..406ccdb7 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -352,6 +352,8 @@ export const useAppStore = defineStore('app', () => {
balance_low_notify_enabled: false,
account_quota_notify_enabled: false,
balance_low_notify_threshold: 0,
+ channel_monitor_enabled: true,
+ channel_monitor_default_interval_seconds: 60,
}
}
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 4587b60a..6f0d9181 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -185,6 +185,8 @@ export interface PublicSettings {
balance_low_notify_enabled: boolean
account_quota_notify_enabled: boolean
balance_low_notify_threshold: number
+ channel_monitor_enabled: boolean
+ channel_monitor_default_interval_seconds: number
}
export interface AuthResponse {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index f3dd653d..b1bf447e 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -3749,6 +3749,52 @@
+
+
+
+
+
+
+ {{ t('admin.settings.features.channelMonitor.title') }}
+
+
+ {{ t('admin.settings.features.channelMonitor.description') }}
+
+
+
+
+
+
+ {{ t('admin.settings.features.channelMonitor.enabled') }}
+
+
+ {{ t('admin.settings.features.channelMonitor.enabledHint') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.features.channelMonitor.defaultInterval') }}
+ *
+
+
+
+ {{ t('admin.settings.features.channelMonitor.defaultIntervalHint') }}
+
+
+
+
+
+
+
@@ -4737,6 +4783,7 @@ const paymentMethodsHref = computed(() =>
type SettingsTab =
| "general"
+ | "features"
| "security"
| "users"
| "gateway"
@@ -4746,6 +4793,7 @@ type SettingsTab =
const activeTab = ref
("general");
const settingsTabs = [
{ key: "general" as SettingsTab, icon: "home" as const },
+ { key: "features" as SettingsTab, icon: "bolt" as const },
{ key: "security" as SettingsTab, icon: "shield" as const },
{ key: "users" as SettingsTab, icon: "user" as const },
{ key: "gateway" as SettingsTab, icon: "server" as const },
@@ -5005,6 +5053,9 @@ const form = reactive({
balance_low_notify_recharge_url: "",
account_quota_notify_enabled: false,
account_quota_notify_emails: [] as NotifyEmailEntry[],
+ // Channel Monitor feature switch
+ channel_monitor_enabled: true,
+ channel_monitor_default_interval_seconds: 60,
});
const authSourceDefaults = reactive(
@@ -5912,6 +5963,10 @@ async function saveSettings() {
account_quota_notify_emails: (
form.account_quota_notify_emails || []
).filter((e) => e.email.trim() !== ""),
+ // Channel Monitor feature switch
+ channel_monitor_enabled: form.channel_monitor_enabled,
+ channel_monitor_default_interval_seconds:
+ Number(form.channel_monitor_default_interval_seconds) || 60,
};
appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults);
--
GitLab
From bf3ef2d19aafa4210c15bdead68df7403098ee66 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:22:17 +0800
Subject: [PATCH 085/261] add admin user last used support
---
backend/internal/handler/dto/mappers.go | 1 +
backend/internal/handler/dto/types.go | 3 +-
.../handler/dto/user_mapper_activity_test.go | 4 +
backend/internal/repository/user_repo.go | 115 ++++++++-----
.../user_repo_sort_integration_test.go | 63 +++++++
backend/internal/service/admin_service.go | 20 +++
.../service/admin_service_list_users_test.go | 41 +++++
backend/internal/service/user.go | 1 +
frontend/src/types/index.ts | 1 +
frontend/src/views/admin/UsersView.vue | 9 +-
.../views/admin/__tests__/UsersView.spec.ts | 162 ++++++++++++++++++
11 files changed, 373 insertions(+), 47 deletions(-)
create mode 100644 frontend/src/views/admin/__tests__/UsersView.spec.ts
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 2f18c64e..d88c110c 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -68,6 +68,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return &AdminUser{
User: *base,
Notes: u.Notes,
+ LastUsedAt: u.LastUsedAt,
GroupRates: u.GroupRates,
}
}
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 72fce1fe..15b8548a 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -36,7 +36,8 @@ type User struct {
type AdminUser struct {
User
- Notes string `json:"notes"`
+ Notes string `json:"notes"`
+ LastUsedAt *time.Time `json:"last_used_at"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go
index 668f886c..1e362fba 100644
--- a/backend/internal/handler/dto/user_mapper_activity_test.go
+++ b/backend/internal/handler/dto/user_mapper_activity_test.go
@@ -13,6 +13,7 @@ func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC)
lastActiveAt := lastLoginAt.Add(15 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(45 * time.Minute)
out := UserFromServiceAdmin(&service.User{
ID: 42,
@@ -22,11 +23,14 @@ func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
Status: service.StatusActive,
LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
})
require.NotNil(t, out)
require.NotNil(t, out.LastLoginAt)
require.NotNil(t, out.LastActiveAt)
+ require.NotNil(t, out.LastUsedAt)
require.WithinDuration(t, lastLoginAt, *out.LastLoginAt, time.Second)
require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 7378611d..25d3f1d6 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -299,51 +299,6 @@ func normalizeEmailAuthIdentitySubject(email string) string {
return normalized
}
-func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
- result := make(map[int64]*time.Time, len(userIDs))
- if len(userIDs) == 0 {
- return result, nil
- }
- if r.sql == nil {
- return nil, fmt.Errorf("sql executor is not configured")
- }
-
- rows, err := r.sql.QueryContext(ctx, `
- SELECT user_id, MAX(created_at) AS last_used_at
- FROM usage_logs
- WHERE user_id = ANY($1)
- GROUP BY user_id
- `, pq.Array(userIDs))
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- for rows.Next() {
- var (
- userID int64
- lastUsedAt time.Time
- )
- if err := rows.Scan(&userID, &lastUsedAt); err != nil {
- return nil, err
- }
- ts := lastUsedAt.UTC()
- result[userID] = &ts
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return result, nil
-}
-
-func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
- latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
- if err != nil {
- return nil, err
- }
- return latestByUserID[userID], nil
-}
-
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
@@ -469,6 +424,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+ if sortBy == "last_used_at" {
+ return userLastUsedAtOrder(sortOrder)
+ }
+
var field string
defaultField := true
nullsLastField := false
@@ -530,6 +489,72 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
+func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ result := make(map[int64]*time.Time, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+ if r.sql == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ const query = `
+ SELECT user_id, MAX(created_at) AS last_used_at
+ FROM usage_logs
+ WHERE user_id = ANY($1)
+ GROUP BY user_id
+ `
+
+ rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var (
+ userID int64
+ lastUsedAt time.Time
+ )
+ if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
+ return nil, scanErr
+ }
+ ts := lastUsedAt.UTC()
+ result[userID] = &ts
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
+func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
+ orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
+ return func(s *entsql.Selector) {
+ subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
+ s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
+ s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
+ }
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){
+ orderExpr("ASC", "FIRST", entsql.Asc),
+ }
+ }
+ return []func(*entsql.Selector){
+ orderExpr("DESC", "LAST", entsql.Desc),
+ }
+}
+
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index 8abef45a..e2445d5b 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -10,6 +10,24 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
+func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) {
+ s.T().Helper()
+
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID})
+
+ _, err := integrationDB.ExecContext(
+ s.ctx,
+ `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at)
+ VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`,
+ userID,
+ apiKey.ID,
+ account.ID,
+ createdAt.UTC(),
+ )
+ s.Require().NoError(err)
+}
+
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
@@ -119,4 +137,49 @@ func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
s.Require().Equal("nil-active@example.com", users[2].Email)
}
+func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second)
+ newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second)
+
+ userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"})
+ userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"})
+ s.mustInsertUsageLog(userWithUsage.ID, older)
+ s.mustInsertUsageLog(userWithUsage.ID, newer)
+
+ got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID})
+ s.Require().NoError(err)
+ s.Require().Contains(got, userWithUsage.ID)
+ s.Require().NotContains(got, userWithoutUsage.ID)
+ s.Require().NotNil(got[userWithUsage.ID])
+ s.Require().True(got[userWithUsage.ID].Equal(newer))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() {
+ lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second)
+ lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second)
+ lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second)
+
+ nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"})
+ wrongSource := s.mustCreateUser(&service.User{
+ Email: "active-not-usage@example.com",
+ LastActiveAt: &lastActiveVeryRecent,
+ })
+ rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"})
+
+ s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder)
+ s.mustInsertUsageLog(rightSource.ID, lastUsedNewer)
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_used_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal(rightSource.ID, users[0].ID)
+ s.Require().Equal(wrongSource.ID, users[1].ID)
+ s.Require().Equal(nilUsage.ID, users[2].ID)
+}
+
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 79840e5b..10b85f76 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -557,6 +557,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
+ if len(users) > 0 {
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
+ }
+ lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr)
+ } else {
+ for i := range users {
+ users[i].LastUsedAt = lastUsedByUserID[users[i].ID]
+ }
+ }
+ }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
@@ -601,6 +615,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if err != nil {
return nil, err
}
+ lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr)
+ } else {
+ user.LastUsedAt = lastUsedAt
+ }
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go
index ceeb52c2..657616c4 100644
--- a/backend/internal/service/admin_service_list_users_test.go
+++ b/backend/internal/service/admin_service_list_users_test.go
@@ -6,6 +6,7 @@ import (
"context"
"errors"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
@@ -16,6 +17,8 @@ type userRepoStubForListUsers struct {
users []User
err error
listWithFiltersParams pagination.PaginationParams
+ lastUsedByUserID map[int64]*time.Time
+ lastUsedErr error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
@@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag
}, nil
}
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ result := make(map[int64]*time.Time, len(userIDs))
+ for _, userID := range userIDs {
+ if ts, ok := s.lastUsedByUserID[userID]; ok {
+ result[userID] = ts
+ }
+ }
+ return result, nil
+}
+
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ return s.lastUsedByUserID[userID], nil
+}
+
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
@@ -130,3 +153,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
+
+func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) {
+ lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second)
+ userRepo := &userRepoStubForListUsers{
+ users: []User{{ID: 101, Email: "u@example.com"}},
+ lastUsedByUserID: map[int64]*time.Time{
+ 101: &lastUsed,
+ },
+ }
+ svc := &adminServiceImpl{userRepo: userRepo}
+
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
+ require.NoError(t, err)
+ require.Equal(t, int64(1), total)
+ require.Len(t, users, 1)
+ require.NotNil(t, users[0].LastUsedAt)
+ require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second)
+}
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index d8b5325c..fa04d95e 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -26,6 +26,7 @@ type User struct {
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
+ LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index a4b2277b..5a2e3184 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -93,6 +93,7 @@ export interface User {
export interface AdminUser extends User {
// 管理员备注(普通用户接口不返回)
notes: string
+ last_used_at?: string | null
// 用户专属分组倍率配置 (group_id -> rate_multiplier)
group_rates?: Record
// 当前并发数(仅管理员列表接口返回)
diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue
index 55a8f2d8..07c9d437 100644
--- a/frontend/src/views/admin/UsersView.vue
+++ b/frontend/src/views/admin/UsersView.vue
@@ -461,6 +461,12 @@
+
+
+ {{ value ? formatDateTime(value) : '-' }}
+
+
+
{{ value ? formatDateTime(value) : '-' }}
@@ -713,6 +719,7 @@ const allColumns = computed(() => [
{ key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true },
{ key: 'status', label: t('admin.users.columns.status'), sortable: true },
{ key: 'last_login_at', label: t('admin.users.columns.lastLogin'), sortable: true },
+ { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true },
{ key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true },
{ key: 'created_at', label: t('admin.users.columns.created'), sortable: true },
{ key: 'actions', label: t('admin.users.columns.actions'), sortable: false }
@@ -801,7 +808,7 @@ const searchQuery = ref('')
const USER_SORT_STORAGE_KEY = 'admin-users-table-sort'
const loadInitialSortState = (): { sort_by: string; sort_order: 'asc' | 'desc' } => {
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
- const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_login_at', 'last_active_at', 'created_at'])
+ const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_login_at', 'last_used_at', 'last_active_at', 'created_at'])
try {
const raw = localStorage.getItem(USER_SORT_STORAGE_KEY)
if (!raw) return fallback
diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts
new file mode 100644
index 00000000..1ea67b63
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts
@@ -0,0 +1,162 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import type { AdminUser } from '@/types'
+import UsersView from '../UsersView.vue'
+
+const {
+ listUsers,
+ getAllGroups,
+ getBatchUsersUsage,
+ listEnabledDefinitions,
+ getBatchUserAttributes
+} = vi.hoisted(() => ({
+ listUsers: vi.fn(),
+ getAllGroups: vi.fn(),
+ getBatchUsersUsage: vi.fn(),
+ listEnabledDefinitions: vi.fn(),
+ getBatchUserAttributes: vi.fn()
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ users: {
+ list: listUsers,
+ toggleStatus: vi.fn(),
+ delete: vi.fn()
+ },
+ groups: {
+ getAll: getAllGroups
+ },
+ dashboard: {
+ getBatchUsersUsage
+ },
+ userAttributes: {
+ listEnabledDefinitions,
+ getBatchUserAttributes
+ }
+ }
+}))
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showError: vi.fn(),
+ showSuccess: vi.fn()
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+const createAdminUser = (): AdminUser => ({
+ id: 42,
+ username: 'scoped-user',
+ email: 'scoped@example.com',
+ role: 'user',
+ balance: 0,
+ concurrency: 1,
+ status: 'active',
+ allowed_groups: [],
+ balance_notify_enabled: false,
+ balance_notify_threshold: null,
+ balance_notify_extra_emails: [],
+ created_at: '2026-04-17T00:00:00Z',
+ updated_at: '2026-04-17T00:00:00Z',
+ notes: '',
+ last_login_at: '2026-04-16T01:00:00Z',
+ last_active_at: '2026-04-16T02:00:00Z',
+ last_used_at: '2026-04-17T02:00:00Z',
+ current_concurrency: 0
+})
+
+const DataTableStub = {
+ props: ['columns', 'data'],
+ emits: ['sort'],
+ template: `
+
+
{{ columns.map(col => col.key).join(',') }}
+
sort
+
+
+
+
+ `
+}
+
+describe('admin UsersView', () => {
+ beforeEach(() => {
+ localStorage.clear()
+
+ listUsers.mockReset()
+ getAllGroups.mockReset()
+ getBatchUsersUsage.mockReset()
+ listEnabledDefinitions.mockReset()
+ getBatchUserAttributes.mockReset()
+
+ listUsers.mockResolvedValue({
+ items: [createAdminUser()],
+ total: 1,
+ page: 1,
+ page_size: 20,
+ pages: 1
+ })
+ getAllGroups.mockResolvedValue([])
+ getBatchUsersUsage.mockResolvedValue({ stats: {} })
+ listEnabledDefinitions.mockResolvedValue([])
+ getBatchUserAttributes.mockResolvedValue({ values: {} })
+ })
+
+ it('shows last_used_at column and requests last_used_at sort', async () => {
+ const wrapper = mount(UsersView, {
+ global: {
+ stubs: {
+ AppLayout: { template: '
' },
+ TablePageLayout: {
+ template: '
'
+ },
+ DataTable: DataTableStub,
+ Pagination: true,
+ ConfirmDialog: true,
+ EmptyState: true,
+ GroupBadge: true,
+ Select: true,
+ UserAttributesConfigModal: true,
+ UserConcurrencyCell: true,
+ UserCreateModal: true,
+ UserEditModal: true,
+ UserApiKeysModal: true,
+ UserAllowedGroupsModal: true,
+ UserBalanceModal: true,
+ UserBalanceHistoryModal: true,
+ GroupReplaceModal: true,
+ Icon: true,
+ Teleport: true
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.get('[data-test="columns"]').text()).toContain('last_used_at')
+
+ await wrapper.get('[data-test="sort-last-used"]').trigger('click')
+ await flushPromises()
+
+ expect(listUsers).toHaveBeenLastCalledWith(
+ 1,
+ 20,
+ expect.objectContaining({
+ sort_by: 'last_used_at',
+ sort_order: 'desc'
+ }),
+ expect.any(Object)
+ )
+ })
+})
--
GitLab
From 0d01bd908eb8bd01d222501870db3b1b3dd44faa Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 00:27:07 +0800
Subject: [PATCH 086/261] refactor(channel-monitor): remove INTELLIGENCE
MONITOR hero title
Subtitle + breadcrumb already convey context; the giant h1 was visual
noise. Drops orphan i18n key `channelStatus.hero.title` and shrinks
hero section vertical padding accordingly.
Bump VERSION to 0.1.114.26
---
frontend/src/components/user/monitor/MonitorHero.vue | 9 ++-------
frontend/src/i18n/locales/en.ts | 1 -
frontend/src/i18n/locales/zh.ts | 1 -
3 files changed, 2 insertions(+), 9 deletions(-)
diff --git a/frontend/src/components/user/monitor/MonitorHero.vue b/frontend/src/components/user/monitor/MonitorHero.vue
index be5a96b8..6857a6fe 100644
--- a/frontend/src/components/user/monitor/MonitorHero.vue
+++ b/frontend/src/components/user/monitor/MonitorHero.vue
@@ -1,16 +1,11 @@
-
+
{{ t('channelStatus.hero.breadcrumb') }}
-
- {{ t('channelStatus.hero.title') }}
-
-
+
{{ t('channelStatus.hero.subtitleZh') }}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 5e1ee189..51aa9920 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -897,7 +897,6 @@ export default {
closeDetail: 'Close',
hero: {
breadcrumb: 'CHANNEL · STATUS',
- title: 'INTELLIGENCE MONITOR',
subtitleZh: 'Real-time tracking of availability, latency and status for leading AI endpoints.',
subtitleEn: 'Advanced performance metrics for next-gen intelligence.'
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index df569d54..021b8992 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -901,7 +901,6 @@ export default {
closeDetail: '关闭',
hero: {
breadcrumb: '渠道 · 状态',
- title: 'INTELLIGENCE MONITOR',
subtitleZh: '实时追踪各大 AI 模型对话接口的可用性、延迟与官方服务状态。',
subtitleEn: 'Advanced performance metrics for next-gen intelligence.'
},
--
GitLab
From ba98243cc2bc230bffb528014331f70276dbffaa Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 01:42:58 +0800
Subject: [PATCH 087/261] feat(channel-monitor): gate UI by feature switch +
polish form UX
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- AppSidebar 三处菜单项(管理端渠道监控、用户端/个人页渠道状态)按
channel_monitor_enabled 条件展开,关闭时隐藏
- ChannelStatusView setInterval 随开关启停:关闭 clearInterval,
开启/未知态自动启动,避免禁用功能后仍在轮询
- MonitorFormDialog provider Select 改为 3 色单选按钮
(openai=emerald / anthropic=orange / gemini=sky),i18n 文案
供应商 → 平台 / Provider → Platform
- MonitorKeyPickerDialog 按钮列表改为 name/key/group 三列表格 +
搜索框,按 key.group.platform === provider 过滤,避免跨平台误选
- form.provider 变化时清空 api_key,修复切换平台仍保留旧 key 的
错配 bug
- providerPickerClass 抽取到 useChannelMonitorFormat composable,
统一 emerald/orange/sky 颜色语义,消除硬编码 Tailwind class 重复
- maskApiKey 工具函数统一(utils/maskApiKey.ts),KeysView 与
MonitorKeyPickerDialog 共用 slice(0,6)...slice(-4) 策略
- bump version to 0.1.114.27
---
.../admin/monitor/MonitorFormDialog.vue | 35 +++++++-
.../admin/monitor/MonitorKeyPickerDialog.vue | 84 ++++++++++++++-----
frontend/src/components/layout/AppSidebar.vue | 12 ++-
.../composables/useChannelMonitorFormat.ts | 27 ++++++
frontend/src/i18n/locales/en.ts | 2 +-
frontend/src/i18n/locales/zh.ts | 2 +-
frontend/src/utils/maskApiKey.ts | 6 ++
frontend/src/views/user/ChannelStatusView.vue | 24 +++++-
frontend/src/views/user/KeysView.vue | 8 +-
9 files changed, 165 insertions(+), 35 deletions(-)
create mode 100644 frontend/src/utils/maskApiKey.ts
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
index e1489ffb..56a06a9f 100644
--- a/frontend/src/components/admin/monitor/MonitorFormDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -13,7 +13,20 @@
{{ t('admin.channelMonitor.form.provider') }} *
-
+
+
+
+ {{ opt.label }}
+
+
@@ -99,6 +112,7 @@
:show="showKeyPicker"
:loading="myKeysLoading"
:keys="myActiveKeys"
+ :provider="form.provider"
@close="showKeyPicker = false"
@pick="pickMyKey"
/>
@@ -119,10 +133,11 @@ import type {
} from '@/api/admin/channelMonitor'
import type { ApiKey } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
-import Select from '@/components/common/Select.vue'
import Toggle from '@/components/common/Toggle.vue'
import ModelTagInput from '@/components/admin/channel/ModelTagInput.vue'
import MonitorKeyPickerDialog from '@/components/admin/monitor/MonitorKeyPickerDialog.vue'
+import ProviderIcon from '@/components/user/monitor/ProviderIcon.vue'
+import { useChannelMonitorFormat } from '@/composables/useChannelMonitorFormat'
import {
PROVIDER_OPENAI,
PROVIDER_ANTHROPIC,
@@ -142,6 +157,7 @@ const emit = defineEmits<{
const { t } = useI18n()
const appStore = useAppStore()
+const { providerPickerClass } = useChannelMonitorFormat()
// System-configured default interval for new monitors. Falls back to the static
// constant when public settings haven't loaded yet or store the legacy 0 value.
@@ -184,12 +200,25 @@ const form = reactive
({
enabled: true,
})
-const providerOptions = computed(() => [
+interface ProviderOption {
+ value: Provider
+ label: string
+}
+
+const providerOptions = computed<ProviderOption[]>(() => [
{ value: PROVIDER_OPENAI, label: t('monitorCommon.providers.openai') },
{ value: PROVIDER_ANTHROPIC, label: t('monitorCommon.providers.anthropic') },
{ value: PROVIDER_GEMINI, label: t('monitorCommon.providers.gemini') },
])
+// Clear api_key whenever provider changes to avoid cross-provider key mismatch.
+// Editing mode loads api_key='' via loadFromMonitor and only sets it on user
+// typing, so clearing on provider change is always a safe no-op until the user
+// picks a new key.
+watch(() => form.provider, () => {
+ form.api_key = ''
+})
+
function resetForm() {
form.name = ''
form.provider = PROVIDER_OPENAI
diff --git a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
index eefe4073..4fd71cb2 100644
--- a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
@@ -2,30 +2,59 @@
{{ t('admin.channelMonitor.form.selectKeyHint') }}
+
+
+
{{ t('common.loading') }}
-
+
{{ t('admin.channelMonitor.form.noActiveKey') }}
-
-
- {{ k.name }}
- {{ maskKey(k.key) }}
-
+
+
+
+
+ {{ t('common.name') }}
+ {{ t('keys.apiKey') }}
+ {{ t('keys.group') }}
+
+
+
+
+ {{ k.name }}
+ {{ maskApiKey(k.key) }}
+
+
+ {{ k.group.name }}
+
+ —
+
+
+
+
@@ -39,14 +68,18 @@
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 23d0f4e9..8b9fbdea 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -611,7 +611,9 @@ const userNavItems = computed((): NavItem[] => {
{ path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon },
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
- { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon },
+ ...(appStore.cachedPublicSettings?.channel_monitor_enabled
+ ? [{ path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }]
+ : []),
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
...(appStore.cachedPublicSettings?.payment_enabled
? [
@@ -650,7 +652,9 @@ const personalNavItems = computed((): NavItem[] => {
const items: NavItem[] = [
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
- { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon },
+ ...(appStore.cachedPublicSettings?.channel_monitor_enabled
+ ? [{ path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }]
+ : []),
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
...(appStore.cachedPublicSettings?.payment_enabled
? [
@@ -715,7 +719,9 @@ const adminNavItems = computed((): NavItem[] => {
expandOnly: true,
children: [
{ path: '/admin/channels/pricing', label: t('nav.channelPricing'), icon: PriceTagIcon },
- { path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon },
+ ...(appStore.cachedPublicSettings?.channel_monitor_enabled
+ ? [{ path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon }]
+ : []),
],
},
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
diff --git a/frontend/src/composables/useChannelMonitorFormat.ts b/frontend/src/composables/useChannelMonitorFormat.ts
index 7ffdaa42..a9253622 100644
--- a/frontend/src/composables/useChannelMonitorFormat.ts
+++ b/frontend/src/composables/useChannelMonitorFormat.ts
@@ -76,6 +76,32 @@ export function useChannelMonitorFormat() {
}
}
+ /**
+ * Tailwind class for a provider radio-button-style picker (active/inactive state).
+ * Reuses the same emerald/orange/sky palette as providerBadgeClass to keep
+ * visual semantics consistent across badges and pickers.
+ */
+ function providerPickerClass(p: Provider | string, active: boolean): string {
+ switch (p) {
+ case PROVIDER_OPENAI:
+ return active
+ ? 'border-emerald-500 bg-emerald-50 text-emerald-700 dark:bg-emerald-500/15 dark:text-emerald-300 dark:border-emerald-400'
+ : 'border-gray-200 bg-white text-gray-600 hover:border-emerald-300 hover:text-emerald-700 dark:border-dark-700 dark:bg-dark-800 dark:text-gray-400 dark:hover:border-emerald-500/50'
+ case PROVIDER_ANTHROPIC:
+ return active
+ ? 'border-orange-500 bg-orange-50 text-orange-700 dark:bg-orange-500/15 dark:text-orange-300 dark:border-orange-400'
+ : 'border-gray-200 bg-white text-gray-600 hover:border-orange-300 hover:text-orange-700 dark:border-dark-700 dark:bg-dark-800 dark:text-gray-400 dark:hover:border-orange-500/50'
+ case PROVIDER_GEMINI:
+ return active
+ ? 'border-sky-500 bg-sky-50 text-sky-700 dark:bg-sky-500/15 dark:text-sky-300 dark:border-sky-400'
+ : 'border-gray-200 bg-white text-gray-600 hover:border-sky-300 hover:text-sky-700 dark:border-dark-700 dark:bg-dark-800 dark:text-gray-400 dark:hover:border-sky-500/50'
+ default:
+ return active
+ ? 'border-gray-400 bg-gray-50 text-gray-700 dark:border-dark-500 dark:bg-dark-700 dark:text-gray-200'
+ : 'border-gray-200 bg-white text-gray-600 hover:border-gray-300 dark:border-dark-700 dark:bg-dark-800 dark:text-gray-400'
+ }
+ }
+
function formatLatency(ms: number | null | undefined): string {
if (ms == null) return t('monitorCommon.latencyEmpty')
return String(Math.round(ms))
@@ -110,6 +136,7 @@ export function useChannelMonitorFormat() {
statusBadgeClass,
providerLabel,
providerBadgeClass,
+ providerPickerClass,
formatLatency,
formatPercent,
formatAvailability,
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 51aa9920..ef3a4057 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2137,7 +2137,7 @@ export default {
form: {
name: 'Name',
namePlaceholder: 'Enter monitor name',
- provider: 'Provider',
+ provider: 'Platform',
endpoint: 'Endpoint',
endpointPlaceholder: 'https://api.example.com',
useCurrentDomain: 'Use current service',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 021b8992..25bce657 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2216,7 +2216,7 @@ export default {
form: {
name: '名称',
namePlaceholder: '输入监控名称',
- provider: '供应商',
+ provider: '平台',
endpoint: '上游地址',
endpointPlaceholder: 'https://api.example.com',
useCurrentDomain: '使用当前服务',
diff --git a/frontend/src/utils/maskApiKey.ts b/frontend/src/utils/maskApiKey.ts
new file mode 100644
index 00000000..ab54980b
--- /dev/null
+++ b/frontend/src/utils/maskApiKey.ts
@@ -0,0 +1,6 @@
+// Mask an API key for display: reveals first 6 + last 4; short keys (≤12) show `first 4 + ***`.
+export function maskApiKey(key: string): string {
+ if (!key) return ''
+ if (key.length <= 12) return `${key.slice(0, 4)}***`
+ return `${key.slice(0, 6)}...${key.slice(-4)}`
+}
diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue
index af427cca..d9100890 100644
--- a/frontend/src/views/user/ChannelStatusView.vue
+++ b/frontend/src/views/user/ChannelStatusView.vue
@@ -160,14 +160,34 @@ watch(items, () => {
void ensureDetailsForWindow()
})
+function startTimer() {
+ if (countdownTimer !== undefined) return
+ countdownTimer = setInterval(tick, 1000) as unknown as number
+}
+
+function stopTimer() {
+ if (countdownTimer !== undefined) {
+ clearInterval(countdownTimer)
+ countdownTimer = undefined
+ }
+}
+
+watch(
+ () => appStore.cachedPublicSettings?.channel_monitor_enabled,
+ (enabled) => {
+ if (enabled === false) stopTimer()
+ else startTimer()
+ },
+)
+
// ── Lifecycle ──
onMounted(() => {
void reload(false)
- countdownTimer = setInterval(tick, 1000) as unknown as number
+ if (appStore.cachedPublicSettings?.channel_monitor_enabled !== false) startTimer()
})
onBeforeUnmount(() => {
- if (countdownTimer !== undefined) clearInterval(countdownTimer)
+ stopTimer()
if (abortController) abortController.abort()
})
diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue
index 34cccf9c..cf29e4bd 100644
--- a/frontend/src/views/user/KeysView.vue
+++ b/frontend/src/views/user/KeysView.vue
@@ -61,7 +61,7 @@
- {{ maskKey(value) }}
+ {{ maskApiKey(value) }}
{
@@ -1260,11 +1261,6 @@ const filteredGroupOptions = computed(() => {
})
})
-const maskKey = (key: string): string => {
- if (key.length <= 12) return key
- return `${key.slice(0, 8)}...${key.slice(-4)}`
-}
-
const copyToClipboard = async (text: string, keyId: number) => {
const success = await clipboardCopy(text, t('keys.copied'))
if (success) {
--
GitLab
From 8cf83c984e10042eaddcb687cc0702f935c2b663 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 10:10:56 +0800
Subject: [PATCH 088/261] feat(channel-monitor): aggregate history to daily
rollups + soft delete
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
明细只保留 1 天,超过 1 天聚合到新表 channel_monitor_daily_rollups(按
monitor_id/model/bucket_date 维度),聚合保留 30 天。两张表都用 SoftDeleteMixin
软删除(DELETE 自动改为 UPDATE deleted_at = NOW())。
聚合 + 清理任务由 OpsCleanupService 的 cron 统一调度,与运维监控的清理共享
schedule(默认 0 2 * * *)和 leader lock。ChannelMonitorRunner 的 cleanupLoop
被移除,只保留 dueCheckLoop。
读取路径 ComputeAvailability* 改为 UNION 明细(今天 deleted_at IS NULL)+
聚合(过去 windowDays 天 deleted_at IS NULL),SUM(ok)/SUM(total) 自然加权
计算可用率,AVG latency 用 SUM(sum_latency_ms)/SUM(count_latency)。
watermark 表 channel_monitor_aggregation_watermark 单行(id=1),记录
last_aggregated_date,重启后从该日期 +1 继续聚合,首次为 nil 则从
today - 30d 开始回填,单次最多 35 天上限避免长事务。
raw SQL 的 ListLatestPerModel / ListLatestForMonitorIDs / ListRecentHistoryForMonitors
都补上 deleted_at IS NULL 过滤(SoftDeleteMixin interceptor 只对 ent query 生效)。
bump version to 0.1.114.28
GroupBadge 在 MonitorKeyPickerDialog 中复用平台主题色 + 倍率/专属倍率
(顺手优化)。
---
backend/cmd/server/wire_gen.go | 2 +-
backend/ent/channelmonitor.go | 18 +-
backend/ent/channelmonitor/channelmonitor.go | 30 +
backend/ent/channelmonitor/where.go | 23 +
backend/ent/channelmonitor_create.go | 32 +
backend/ent/channelmonitor_query.go | 102 +-
backend/ent/channelmonitor_update.go | 163 ++
backend/ent/channelmonitordailyrollup.go | 292 +++
.../channelmonitordailyrollup.go | 222 ++
.../ent/channelmonitordailyrollup/where.go | 784 ++++++
.../ent/channelmonitordailyrollup_create.go | 1593 +++++++++++++
.../ent/channelmonitordailyrollup_delete.go | 88 +
.../ent/channelmonitordailyrollup_query.go | 643 +++++
.../ent/channelmonitordailyrollup_update.go | 1025 ++++++++
backend/ent/channelmonitorhistory.go | 16 +-
.../channelmonitorhistory.go | 16 +
backend/ent/channelmonitorhistory/where.go | 55 +
backend/ent/channelmonitorhistory_create.go | 94 +-
backend/ent/channelmonitorhistory_query.go | 8 +-
backend/ent/channelmonitorhistory_update.go | 52 +
backend/ent/client.go | 369 ++-
backend/ent/ent.go | 66 +-
backend/ent/hook/hook.go | 12 +
backend/ent/intercept/intercept.go | 30 +
backend/ent/migrate/schema.go | 57 +-
backend/ent/mutation.go | 2099 +++++++++++++++--
backend/ent/predicate/predicate.go | 3 +
backend/ent/runtime/runtime.go | 77 +
backend/ent/schema/channel_monitor.go | 2 +
.../schema/channel_monitor_daily_rollup.go | 73 +
backend/ent/schema/channel_monitor_history.go | 11 +-
backend/ent/tx.go | 3 +
.../repository/channel_monitor_repo.go | 201 +-
.../internal/service/channel_monitor_const.go | 21 +-
.../service/channel_monitor_runner.go | 57 +-
.../service/channel_monitor_service.go | 107 +-
.../internal/service/ops_cleanup_service.go | 33 +-
backend/internal/service/wire.go | 5 +-
.../126_add_channel_monitor_aggregation.sql | 60 +
.../admin/monitor/MonitorFormDialog.vue | 9 +-
.../admin/monitor/MonitorKeyPickerDialog.vue | 19 +-
41 files changed, 8088 insertions(+), 484 deletions(-)
create mode 100644 backend/ent/channelmonitordailyrollup.go
create mode 100644 backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
create mode 100644 backend/ent/channelmonitordailyrollup/where.go
create mode 100644 backend/ent/channelmonitordailyrollup_create.go
create mode 100644 backend/ent/channelmonitordailyrollup_delete.go
create mode 100644 backend/ent/channelmonitordailyrollup_query.go
create mode 100644 backend/ent/channelmonitordailyrollup_update.go
create mode 100644 backend/ent/schema/channel_monitor_daily_rollup.go
create mode 100644 backend/migrations/126_add_channel_monitor_aggregation.sql
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 754f814a..a878ea68 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -252,7 +252,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
- opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
+ opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
diff --git a/backend/ent/channelmonitor.go b/backend/ent/channelmonitor.go
index 292c2b28..58886884 100644
--- a/backend/ent/channelmonitor.go
+++ b/backend/ent/channelmonitor.go
@@ -54,9 +54,11 @@ type ChannelMonitor struct {
type ChannelMonitorEdges struct {
// History holds the value of the history edge.
History []*ChannelMonitorHistory `json:"history,omitempty"`
+ // DailyRollups holds the value of the daily_rollups edge.
+ DailyRollups []*ChannelMonitorDailyRollup `json:"daily_rollups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [1]bool
+ loadedTypes [2]bool
}
// HistoryOrErr returns the History value or an error if the edge
@@ -68,6 +70,15 @@ func (e ChannelMonitorEdges) HistoryOrErr() ([]*ChannelMonitorHistory, error) {
return nil, &NotLoadedError{edge: "history"}
}
+// DailyRollupsOrErr returns the DailyRollups value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorEdges) DailyRollupsOrErr() ([]*ChannelMonitorDailyRollup, error) {
+ if e.loadedTypes[1] {
+ return e.DailyRollups, nil
+ }
+ return nil, &NotLoadedError{edge: "daily_rollups"}
+}
+
// scanValues returns the types for scanning values from sql.Rows.
func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
@@ -203,6 +214,11 @@ func (_m *ChannelMonitor) QueryHistory() *ChannelMonitorHistoryQuery {
return NewChannelMonitorClient(_m.config).QueryHistory(_m)
}
+// QueryDailyRollups queries the "daily_rollups" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
+ return NewChannelMonitorClient(_m.config).QueryDailyRollups(_m)
+}
+
// Update returns a builder for updating this ChannelMonitor.
// Note that you need to call ChannelMonitor.Unwrap() before calling this method if this ChannelMonitor
// was returned from a transaction, and the transaction was committed or rolled back.
diff --git a/backend/ent/channelmonitor/channelmonitor.go b/backend/ent/channelmonitor/channelmonitor.go
index c5ab8199..ff6d7105 100644
--- a/backend/ent/channelmonitor/channelmonitor.go
+++ b/backend/ent/channelmonitor/channelmonitor.go
@@ -43,6 +43,8 @@ const (
FieldCreatedBy = "created_by"
// EdgeHistory holds the string denoting the history edge name in mutations.
EdgeHistory = "history"
+ // EdgeDailyRollups holds the string denoting the daily_rollups edge name in mutations.
+ EdgeDailyRollups = "daily_rollups"
// Table holds the table name of the channelmonitor in the database.
Table = "channel_monitors"
// HistoryTable is the table that holds the history relation/edge.
@@ -52,6 +54,13 @@ const (
HistoryInverseTable = "channel_monitor_histories"
// HistoryColumn is the table column denoting the history relation/edge.
HistoryColumn = "monitor_id"
+ // DailyRollupsTable is the table that holds the daily_rollups relation/edge.
+ DailyRollupsTable = "channel_monitor_daily_rollups"
+ // DailyRollupsInverseTable is the table name for the ChannelMonitorDailyRollup entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitordailyrollup" package.
+ DailyRollupsInverseTable = "channel_monitor_daily_rollups"
+ // DailyRollupsColumn is the table column denoting the daily_rollups relation/edge.
+ DailyRollupsColumn = "monitor_id"
)
// Columns holds all SQL columns for channelmonitor fields.
@@ -214,6 +223,20 @@ func ByHistory(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
sqlgraph.OrderByNeighborTerms(s, newHistoryStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
+
+// ByDailyRollupsCount orders the results by daily_rollups count.
+func ByDailyRollupsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newDailyRollupsStep(), opts...)
+ }
+}
+
+// ByDailyRollups orders the results by daily_rollups terms.
+func ByDailyRollups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newDailyRollupsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
func newHistoryStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
@@ -221,3 +244,10 @@ func newHistoryStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn),
)
}
+func newDailyRollupsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(DailyRollupsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
+ )
+}
diff --git a/backend/ent/channelmonitor/where.go b/backend/ent/channelmonitor/where.go
index 8126fb77..abb8484d 100644
--- a/backend/ent/channelmonitor/where.go
+++ b/backend/ent/channelmonitor/where.go
@@ -708,6 +708,29 @@ func HasHistoryWith(preds ...predicate.ChannelMonitorHistory) predicate.ChannelM
})
}
+// HasDailyRollups applies the HasEdge predicate on the "daily_rollups" edge.
+func HasDailyRollups() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasDailyRollupsWith applies the HasEdge predicate on the "daily_rollups" edge with a given conditions (other predicates).
+func HasDailyRollupsWith(preds ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newDailyRollupsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.AndPredicates(predicates...))
diff --git a/backend/ent/channelmonitor_create.go b/backend/ent/channelmonitor_create.go
index ad735f3e..30a7b40d 100644
--- a/backend/ent/channelmonitor_create.go
+++ b/backend/ent/channelmonitor_create.go
@@ -12,6 +12,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
)
@@ -156,6 +157,21 @@ func (_c *ChannelMonitorCreate) AddHistory(v ...*ChannelMonitorHistory) *Channel
return _c.AddHistoryIDs(ids...)
}
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_c *ChannelMonitorCreate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorCreate {
+ _c.mutation.AddDailyRollupIDs(ids...)
+ return _c
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_c *ChannelMonitorCreate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddDailyRollupIDs(ids...)
+}
+
// Mutation returns the ChannelMonitorMutation object of the builder.
func (_c *ChannelMonitorCreate) Mutation() *ChannelMonitorMutation {
return _c.mutation
@@ -378,6 +394,22 @@ func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateS
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
diff --git a/backend/ent/channelmonitor_query.go b/backend/ent/channelmonitor_query.go
index 6a532587..2ebd95bb 100644
--- a/backend/ent/channelmonitor_query.go
+++ b/backend/ent/channelmonitor_query.go
@@ -14,6 +14,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
@@ -21,12 +22,13 @@ import (
// ChannelMonitorQuery is the builder for querying ChannelMonitor entities.
type ChannelMonitorQuery struct {
config
- ctx *QueryContext
- order []channelmonitor.OrderOption
- inters []Interceptor
- predicates []predicate.ChannelMonitor
- withHistory *ChannelMonitorHistoryQuery
- modifiers []func(*sql.Selector)
+ ctx *QueryContext
+ order []channelmonitor.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitor
+ withHistory *ChannelMonitorHistoryQuery
+ withDailyRollups *ChannelMonitorDailyRollupQuery
+ modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -85,6 +87,28 @@ func (_q *ChannelMonitorQuery) QueryHistory() *ChannelMonitorHistoryQuery {
return query
}
+// QueryDailyRollups chains the current query on the "daily_rollups" edge.
+func (_q *ChannelMonitorQuery) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// First returns the first ChannelMonitor entity from the query.
// Returns a *NotFoundError when no ChannelMonitor was found.
func (_q *ChannelMonitorQuery) First(ctx context.Context) (*ChannelMonitor, error) {
@@ -272,12 +296,13 @@ func (_q *ChannelMonitorQuery) Clone() *ChannelMonitorQuery {
return nil
}
return &ChannelMonitorQuery{
- config: _q.config,
- ctx: _q.ctx.Clone(),
- order: append([]channelmonitor.OrderOption{}, _q.order...),
- inters: append([]Interceptor{}, _q.inters...),
- predicates: append([]predicate.ChannelMonitor{}, _q.predicates...),
- withHistory: _q.withHistory.Clone(),
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitor.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitor{}, _q.predicates...),
+ withHistory: _q.withHistory.Clone(),
+ withDailyRollups: _q.withDailyRollups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
@@ -295,6 +320,17 @@ func (_q *ChannelMonitorQuery) WithHistory(opts ...func(*ChannelMonitorHistoryQu
return _q
}
+// WithDailyRollups tells the query-builder to eager-load the nodes that are connected to
+// the "daily_rollups" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithDailyRollups(opts ...func(*ChannelMonitorDailyRollupQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withDailyRollups = query
+ return _q
+}
+
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -373,8 +409,9 @@ func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) (
var (
nodes = []*ChannelMonitor{}
_spec = _q.querySpec()
- loadedTypes = [1]bool{
+ loadedTypes = [2]bool{
_q.withHistory != nil,
+ _q.withDailyRollups != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
@@ -405,6 +442,15 @@ func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) (
return nil, err
}
}
+ if query := _q.withDailyRollups; query != nil {
+ if err := _q.loadDailyRollups(ctx, query, nodes,
+ func(n *ChannelMonitor) { n.Edges.DailyRollups = []*ChannelMonitorDailyRollup{} },
+ func(n *ChannelMonitor, e *ChannelMonitorDailyRollup) {
+ n.Edges.DailyRollups = append(n.Edges.DailyRollups, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
return nodes, nil
}
@@ -438,6 +484,36 @@ func (_q *ChannelMonitorQuery) loadHistory(ctx context.Context, query *ChannelMo
}
return nil
}
+func (_q *ChannelMonitorQuery) loadDailyRollups(ctx context.Context, query *ChannelMonitorDailyRollupQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorDailyRollup)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitor)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitordailyrollup.FieldMonitorID)
+ }
+ query.Where(predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitor.DailyRollupsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.MonitorID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *ChannelMonitorQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
diff --git a/backend/ent/channelmonitor_update.go b/backend/ent/channelmonitor_update.go
index df575a9f..7ba4e449 100644
--- a/backend/ent/channelmonitor_update.go
+++ b/backend/ent/channelmonitor_update.go
@@ -13,6 +13,7 @@ import (
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
@@ -229,6 +230,21 @@ func (_u *ChannelMonitorUpdate) AddHistory(v ...*ChannelMonitorHistory) *Channel
return _u.AddHistoryIDs(ids...)
}
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_u *ChannelMonitorUpdate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.AddDailyRollupIDs(ids...)
+ return _u
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddDailyRollupIDs(ids...)
+}
+
// Mutation returns the ChannelMonitorMutation object of the builder.
func (_u *ChannelMonitorUpdate) Mutation() *ChannelMonitorMutation {
return _u.mutation
@@ -255,6 +271,27 @@ func (_u *ChannelMonitorUpdate) RemoveHistory(v ...*ChannelMonitorHistory) *Chan
return _u.RemoveHistoryIDs(ids...)
}
+// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdate) ClearDailyRollups() *ChannelMonitorUpdate {
+ _u.mutation.ClearDailyRollups()
+ return _u
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs.
+func (_u *ChannelMonitorUpdate) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.RemoveDailyRollupIDs(ids...)
+ return _u
+}
+
+// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities.
+func (_u *ChannelMonitorUpdate) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveDailyRollupIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *ChannelMonitorUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
@@ -441,6 +478,51 @@ func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err err
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{channelmonitor.Label}
@@ -660,6 +742,21 @@ func (_u *ChannelMonitorUpdateOne) AddHistory(v ...*ChannelMonitorHistory) *Chan
return _u.AddHistoryIDs(ids...)
}
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_u *ChannelMonitorUpdateOne) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddDailyRollupIDs(ids...)
+ return _u
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdateOne) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddDailyRollupIDs(ids...)
+}
+
// Mutation returns the ChannelMonitorMutation object of the builder.
func (_u *ChannelMonitorUpdateOne) Mutation() *ChannelMonitorMutation {
return _u.mutation
@@ -686,6 +783,27 @@ func (_u *ChannelMonitorUpdateOne) RemoveHistory(v ...*ChannelMonitorHistory) *C
return _u.RemoveHistoryIDs(ids...)
}
+// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdateOne) ClearDailyRollups() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearDailyRollups()
+ return _u
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs.
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.RemoveDailyRollupIDs(ids...)
+ return _u
+}
+
+// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities.
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveDailyRollupIDs(ids...)
+}
+
// Where appends a list predicates to the ChannelMonitorUpdate builder.
func (_u *ChannelMonitorUpdateOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdateOne {
_u.mutation.Where(ps...)
@@ -902,6 +1020,51 @@ func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelM
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &ChannelMonitor{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/ent/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup.go
new file mode 100644
index 00000000..6c7a8afa
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup.go
@@ -0,0 +1,292 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+)
+
+// ChannelMonitorDailyRollup is the model entity for the ChannelMonitorDailyRollup schema.
+type ChannelMonitorDailyRollup struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // DeletedAt holds the value of the "deleted_at" field.
+ DeletedAt *time.Time `json:"deleted_at,omitempty"`
+ // MonitorID holds the value of the "monitor_id" field.
+ MonitorID int64 `json:"monitor_id,omitempty"`
+ // Model holds the value of the "model" field.
+ Model string `json:"model,omitempty"`
+ // BucketDate holds the value of the "bucket_date" field.
+ BucketDate time.Time `json:"bucket_date,omitempty"`
+ // TotalChecks holds the value of the "total_checks" field.
+ TotalChecks int `json:"total_checks,omitempty"`
+ // OkCount holds the value of the "ok_count" field.
+ OkCount int `json:"ok_count,omitempty"`
+ // OperationalCount holds the value of the "operational_count" field.
+ OperationalCount int `json:"operational_count,omitempty"`
+ // DegradedCount holds the value of the "degraded_count" field.
+ DegradedCount int `json:"degraded_count,omitempty"`
+ // FailedCount holds the value of the "failed_count" field.
+ FailedCount int `json:"failed_count,omitempty"`
+ // ErrorCount holds the value of the "error_count" field.
+ ErrorCount int `json:"error_count,omitempty"`
+ // SumLatencyMs holds the value of the "sum_latency_ms" field.
+ SumLatencyMs int64 `json:"sum_latency_ms,omitempty"`
+ // CountLatency holds the value of the "count_latency" field.
+ CountLatency int `json:"count_latency,omitempty"`
+ // SumPingLatencyMs holds the value of the "sum_ping_latency_ms" field.
+ SumPingLatencyMs int64 `json:"sum_ping_latency_ms,omitempty"`
+ // CountPingLatency holds the value of the "count_ping_latency" field.
+ CountPingLatency int `json:"count_ping_latency,omitempty"`
+ // ComputedAt holds the value of the "computed_at" field.
+ ComputedAt time.Time `json:"computed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorDailyRollupQuery when eager-loading is set.
+ Edges ChannelMonitorDailyRollupEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorDailyRollupEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorDailyRollupEdges struct {
+ // Monitor holds the value of the monitor edge.
+ Monitor *ChannelMonitor `json:"monitor,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorOrErr returns the Monitor value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorDailyRollupEdges) MonitorOrErr() (*ChannelMonitor, error) {
+ if e.Monitor != nil {
+ return e.Monitor, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: channelmonitor.Label}
+ }
+ return nil, &NotLoadedError{edge: "monitor"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorDailyRollup) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitordailyrollup.FieldID, channelmonitordailyrollup.FieldMonitorID, channelmonitordailyrollup.FieldTotalChecks, channelmonitordailyrollup.FieldOkCount, channelmonitordailyrollup.FieldOperationalCount, channelmonitordailyrollup.FieldDegradedCount, channelmonitordailyrollup.FieldFailedCount, channelmonitordailyrollup.FieldErrorCount, channelmonitordailyrollup.FieldSumLatencyMs, channelmonitordailyrollup.FieldCountLatency, channelmonitordailyrollup.FieldSumPingLatencyMs, channelmonitordailyrollup.FieldCountPingLatency:
+ values[i] = new(sql.NullInt64)
+ case channelmonitordailyrollup.FieldModel:
+ values[i] = new(sql.NullString)
+ case channelmonitordailyrollup.FieldDeletedAt, channelmonitordailyrollup.FieldBucketDate, channelmonitordailyrollup.FieldComputedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorDailyRollup fields.
+func (_m *ChannelMonitorDailyRollup) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitordailyrollup.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitordailyrollup.FieldDeletedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
+ } else if value.Valid {
+ _m.DeletedAt = new(time.Time)
+ *_m.DeletedAt = value.Time
+ }
+ case channelmonitordailyrollup.FieldMonitorID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
+ } else if value.Valid {
+ _m.MonitorID = value.Int64
+ }
+ case channelmonitordailyrollup.FieldModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field model", values[i])
+ } else if value.Valid {
+ _m.Model = value.String
+ }
+ case channelmonitordailyrollup.FieldBucketDate:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field bucket_date", values[i])
+ } else if value.Valid {
+ _m.BucketDate = value.Time
+ }
+ case channelmonitordailyrollup.FieldTotalChecks:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field total_checks", values[i])
+ } else if value.Valid {
+ _m.TotalChecks = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldOkCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field ok_count", values[i])
+ } else if value.Valid {
+ _m.OkCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldOperationalCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field operational_count", values[i])
+ } else if value.Valid {
+ _m.OperationalCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldDegradedCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field degraded_count", values[i])
+ } else if value.Valid {
+ _m.DegradedCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldFailedCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field failed_count", values[i])
+ } else if value.Valid {
+ _m.FailedCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldErrorCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field error_count", values[i])
+ } else if value.Valid {
+ _m.ErrorCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sum_latency_ms", values[i])
+ } else if value.Valid {
+ _m.SumLatencyMs = value.Int64
+ }
+ case channelmonitordailyrollup.FieldCountLatency:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field count_latency", values[i])
+ } else if value.Valid {
+ _m.CountLatency = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sum_ping_latency_ms", values[i])
+ } else if value.Valid {
+ _m.SumPingLatencyMs = value.Int64
+ }
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field count_ping_latency", values[i])
+ } else if value.Valid {
+ _m.CountPingLatency = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldComputedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field computed_at", values[i])
+ } else if value.Valid {
+ _m.ComputedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorDailyRollup.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorDailyRollup) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitor queries the "monitor" edge of the ChannelMonitorDailyRollup entity.
+func (_m *ChannelMonitorDailyRollup) QueryMonitor() *ChannelMonitorQuery {
+ return NewChannelMonitorDailyRollupClient(_m.config).QueryMonitor(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorDailyRollup.
+// Note that you need to call ChannelMonitorDailyRollup.Unwrap() before calling this method if this ChannelMonitorDailyRollup
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorDailyRollup) Update() *ChannelMonitorDailyRollupUpdateOne {
+ return NewChannelMonitorDailyRollupClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorDailyRollup entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorDailyRollup) Unwrap() *ChannelMonitorDailyRollup {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorDailyRollup is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorDailyRollup) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorDailyRollup(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ if v := _m.DeletedAt; v != nil {
+ builder.WriteString("deleted_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("monitor_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
+ builder.WriteString(", ")
+ builder.WriteString("model=")
+ builder.WriteString(_m.Model)
+ builder.WriteString(", ")
+ builder.WriteString("bucket_date=")
+ builder.WriteString(_m.BucketDate.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("total_checks=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotalChecks))
+ builder.WriteString(", ")
+ builder.WriteString("ok_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.OkCount))
+ builder.WriteString(", ")
+ builder.WriteString("operational_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.OperationalCount))
+ builder.WriteString(", ")
+ builder.WriteString("degraded_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.DegradedCount))
+ builder.WriteString(", ")
+ builder.WriteString("failed_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.FailedCount))
+ builder.WriteString(", ")
+ builder.WriteString("error_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ErrorCount))
+ builder.WriteString(", ")
+ builder.WriteString("sum_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SumLatencyMs))
+ builder.WriteString(", ")
+ builder.WriteString("count_latency=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CountLatency))
+ builder.WriteString(", ")
+ builder.WriteString("sum_ping_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SumPingLatencyMs))
+ builder.WriteString(", ")
+ builder.WriteString("count_ping_latency=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CountPingLatency))
+ builder.WriteString(", ")
+ builder.WriteString("computed_at=")
+ builder.WriteString(_m.ComputedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorDailyRollups is a parsable slice of ChannelMonitorDailyRollup.
+type ChannelMonitorDailyRollups []*ChannelMonitorDailyRollup
diff --git a/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
new file mode 100644
index 00000000..eb1f69a8
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
@@ -0,0 +1,222 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitordailyrollup
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitordailyrollup type in the database.
+ Label = "channel_monitor_daily_rollup"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldDeletedAt holds the string denoting the deleted_at field in the database.
+ FieldDeletedAt = "deleted_at"
+ // FieldMonitorID holds the string denoting the monitor_id field in the database.
+ FieldMonitorID = "monitor_id"
+ // FieldModel holds the string denoting the model field in the database.
+ FieldModel = "model"
+ // FieldBucketDate holds the string denoting the bucket_date field in the database.
+ FieldBucketDate = "bucket_date"
+ // FieldTotalChecks holds the string denoting the total_checks field in the database.
+ FieldTotalChecks = "total_checks"
+ // FieldOkCount holds the string denoting the ok_count field in the database.
+ FieldOkCount = "ok_count"
+ // FieldOperationalCount holds the string denoting the operational_count field in the database.
+ FieldOperationalCount = "operational_count"
+ // FieldDegradedCount holds the string denoting the degraded_count field in the database.
+ FieldDegradedCount = "degraded_count"
+ // FieldFailedCount holds the string denoting the failed_count field in the database.
+ FieldFailedCount = "failed_count"
+ // FieldErrorCount holds the string denoting the error_count field in the database.
+ FieldErrorCount = "error_count"
+ // FieldSumLatencyMs holds the string denoting the sum_latency_ms field in the database.
+ FieldSumLatencyMs = "sum_latency_ms"
+ // FieldCountLatency holds the string denoting the count_latency field in the database.
+ FieldCountLatency = "count_latency"
+ // FieldSumPingLatencyMs holds the string denoting the sum_ping_latency_ms field in the database.
+ FieldSumPingLatencyMs = "sum_ping_latency_ms"
+ // FieldCountPingLatency holds the string denoting the count_ping_latency field in the database.
+ FieldCountPingLatency = "count_ping_latency"
+ // FieldComputedAt holds the string denoting the computed_at field in the database.
+ FieldComputedAt = "computed_at"
+ // EdgeMonitor holds the string denoting the monitor edge name in mutations.
+ EdgeMonitor = "monitor"
+ // Table holds the table name of the channelmonitordailyrollup in the database.
+ Table = "channel_monitor_daily_rollups"
+ // MonitorTable is the table that holds the monitor relation/edge.
+ MonitorTable = "channel_monitor_daily_rollups"
+ // MonitorInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorInverseTable = "channel_monitors"
+ // MonitorColumn is the table column denoting the monitor relation/edge.
+ MonitorColumn = "monitor_id"
+)
+
+// Columns holds all SQL columns for channelmonitordailyrollup fields.
+var Columns = []string{
+ FieldID,
+ FieldDeletedAt,
+ FieldMonitorID,
+ FieldModel,
+ FieldBucketDate,
+ FieldTotalChecks,
+ FieldOkCount,
+ FieldOperationalCount,
+ FieldDegradedCount,
+ FieldFailedCount,
+ FieldErrorCount,
+ FieldSumLatencyMs,
+ FieldCountLatency,
+ FieldSumPingLatencyMs,
+ FieldCountPingLatency,
+ FieldComputedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+// Note that the variables below are initialized by the runtime
+// package on the initialization of the application. Therefore,
+// it should be imported in the main as follows:
+//
+// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
+var (
+ Hooks [1]ent.Hook
+ Interceptors [1]ent.Interceptor
+ // ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ ModelValidator func(string) error
+ // DefaultTotalChecks holds the default value on creation for the "total_checks" field.
+ DefaultTotalChecks int
+ // DefaultOkCount holds the default value on creation for the "ok_count" field.
+ DefaultOkCount int
+ // DefaultOperationalCount holds the default value on creation for the "operational_count" field.
+ DefaultOperationalCount int
+ // DefaultDegradedCount holds the default value on creation for the "degraded_count" field.
+ DefaultDegradedCount int
+ // DefaultFailedCount holds the default value on creation for the "failed_count" field.
+ DefaultFailedCount int
+ // DefaultErrorCount holds the default value on creation for the "error_count" field.
+ DefaultErrorCount int
+ // DefaultSumLatencyMs holds the default value on creation for the "sum_latency_ms" field.
+ DefaultSumLatencyMs int64
+ // DefaultCountLatency holds the default value on creation for the "count_latency" field.
+ DefaultCountLatency int
+ // DefaultSumPingLatencyMs holds the default value on creation for the "sum_ping_latency_ms" field.
+ DefaultSumPingLatencyMs int64
+ // DefaultCountPingLatency holds the default value on creation for the "count_ping_latency" field.
+ DefaultCountPingLatency int
+ // DefaultComputedAt holds the default value on creation for the "computed_at" field.
+ DefaultComputedAt func() time.Time
+ // UpdateDefaultComputedAt holds the default value on update for the "computed_at" field.
+ UpdateDefaultComputedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the ChannelMonitorDailyRollup queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByDeletedAt orders the results by the deleted_at field.
+func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
+}
+
+// ByMonitorID orders the results by the monitor_id field.
+func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
+}
+
+// ByModel orders the results by the model field.
+func ByModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModel, opts...).ToFunc()
+}
+
+// ByBucketDate orders the results by the bucket_date field.
+func ByBucketDate(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBucketDate, opts...).ToFunc()
+}
+
+// ByTotalChecks orders the results by the total_checks field.
+func ByTotalChecks(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotalChecks, opts...).ToFunc()
+}
+
+// ByOkCount orders the results by the ok_count field.
+func ByOkCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOkCount, opts...).ToFunc()
+}
+
+// ByOperationalCount orders the results by the operational_count field.
+func ByOperationalCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOperationalCount, opts...).ToFunc()
+}
+
+// ByDegradedCount orders the results by the degraded_count field.
+func ByDegradedCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDegradedCount, opts...).ToFunc()
+}
+
+// ByFailedCount orders the results by the failed_count field.
+func ByFailedCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFailedCount, opts...).ToFunc()
+}
+
+// ByErrorCount orders the results by the error_count field.
+func ByErrorCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldErrorCount, opts...).ToFunc()
+}
+
+// BySumLatencyMs orders the results by the sum_latency_ms field.
+func BySumLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSumLatencyMs, opts...).ToFunc()
+}
+
+// ByCountLatency orders the results by the count_latency field.
+func ByCountLatency(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCountLatency, opts...).ToFunc()
+}
+
+// BySumPingLatencyMs orders the results by the sum_ping_latency_ms field.
+func BySumPingLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSumPingLatencyMs, opts...).ToFunc()
+}
+
+// ByCountPingLatency orders the results by the count_ping_latency field.
+func ByCountPingLatency(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCountPingLatency, opts...).ToFunc()
+}
+
+// ByComputedAt orders the results by the computed_at field.
+func ByComputedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldComputedAt, opts...).ToFunc()
+}
+
+// ByMonitorField orders the results by monitor field.
+func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newMonitorStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+}
diff --git a/backend/ent/channelmonitordailyrollup/where.go b/backend/ent/channelmonitordailyrollup/where.go
new file mode 100644
index 00000000..9da8d4be
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup/where.go
@@ -0,0 +1,784 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitordailyrollup
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldID, id))
+}
+
+// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
+func DeletedAt(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDeletedAt, v))
+}
+
+// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
+func MonitorID(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// Model applies equality check predicate on the "model" field. It's identical to ModelEQ.
+func Model(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v))
+}
+
+// BucketDate applies equality check predicate on the "bucket_date" field. It's identical to BucketDateEQ.
+func BucketDate(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v))
+}
+
+// TotalChecks applies equality check predicate on the "total_checks" field. It's identical to TotalChecksEQ.
+func TotalChecks(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v))
+}
+
+// OkCount applies equality check predicate on the "ok_count" field. It's identical to OkCountEQ.
+func OkCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v))
+}
+
+// OperationalCount applies equality check predicate on the "operational_count" field. It's identical to OperationalCountEQ.
+func OperationalCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v))
+}
+
+// DegradedCount applies equality check predicate on the "degraded_count" field. It's identical to DegradedCountEQ.
+func DegradedCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v))
+}
+
+// FailedCount applies equality check predicate on the "failed_count" field. It's identical to FailedCountEQ.
+func FailedCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v))
+}
+
+// ErrorCount applies equality check predicate on the "error_count" field. It's identical to ErrorCountEQ.
+func ErrorCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v))
+}
+
+// SumLatencyMs applies equality check predicate on the "sum_latency_ms" field. It's identical to SumLatencyMsEQ.
+func SumLatencyMs(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v))
+}
+
+// CountLatency applies equality check predicate on the "count_latency" field. It's identical to CountLatencyEQ.
+func CountLatency(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v))
+}
+
+// SumPingLatencyMs applies equality check predicate on the "sum_ping_latency_ms" field. It's identical to SumPingLatencyMsEQ.
+func SumPingLatencyMs(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v))
+}
+
+// CountPingLatency applies equality check predicate on the "count_ping_latency" field. It's identical to CountPingLatencyEQ.
+func CountPingLatency(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v))
+}
+
+// ComputedAt applies equality check predicate on the "computed_at" field. It's identical to ComputedAtEQ.
+func ComputedAt(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
+}
+
+// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
+func DeletedAtEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDeletedAt, v))
+}
+
+// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
+func DeletedAtNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldDeletedAt, v))
+}
+
+// DeletedAtIn applies the In predicate on the "deleted_at" field.
+func DeletedAtIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldDeletedAt, vs...))
+}
+
+// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
+func DeletedAtNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldDeletedAt, vs...))
+}
+
+// DeletedAtGT applies the GT predicate on the "deleted_at" field.
+func DeletedAtGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldDeletedAt, v))
+}
+
+// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
+func DeletedAtGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldDeletedAt, v))
+}
+
+// DeletedAtLT applies the LT predicate on the "deleted_at" field.
+func DeletedAtLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldDeletedAt, v))
+}
+
+// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
+func DeletedAtLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldDeletedAt, v))
+}
+
+// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
+func DeletedAtIsNil() predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIsNull(FieldDeletedAt))
+}
+
+// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
+func DeletedAtNotNil() predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotNull(FieldDeletedAt))
+}
+
+// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
+func MonitorIDEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field.
+func MonitorIDNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldMonitorID, v))
+}
+
+// MonitorIDIn applies the In predicate on the "monitor_id" field.
+func MonitorIDIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldMonitorID, vs...))
+}
+
+// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field.
+func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldMonitorID, vs...))
+}
+
+// ModelEQ applies the EQ predicate on the "model" field.
+func ModelEQ(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v))
+}
+
+// ModelNEQ applies the NEQ predicate on the "model" field.
+func ModelNEQ(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldModel, v))
+}
+
+// ModelIn applies the In predicate on the "model" field.
+func ModelIn(vs ...string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldModel, vs...))
+}
+
+// ModelNotIn applies the NotIn predicate on the "model" field.
+func ModelNotIn(vs ...string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldModel, vs...))
+}
+
+// ModelGT applies the GT predicate on the "model" field.
+func ModelGT(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldModel, v))
+}
+
+// ModelGTE applies the GTE predicate on the "model" field.
+func ModelGTE(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldModel, v))
+}
+
+// ModelLT applies the LT predicate on the "model" field.
+func ModelLT(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldModel, v))
+}
+
+// ModelLTE applies the LTE predicate on the "model" field.
+func ModelLTE(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldModel, v))
+}
+
+// ModelContains applies the Contains predicate on the "model" field.
+func ModelContains(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldContains(FieldModel, v))
+}
+
+// ModelHasPrefix applies the HasPrefix predicate on the "model" field.
+func ModelHasPrefix(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldHasPrefix(FieldModel, v))
+}
+
+// ModelHasSuffix applies the HasSuffix predicate on the "model" field.
+func ModelHasSuffix(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldHasSuffix(FieldModel, v))
+}
+
+// ModelEqualFold applies the EqualFold predicate on the "model" field.
+func ModelEqualFold(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEqualFold(FieldModel, v))
+}
+
+// ModelContainsFold applies the ContainsFold predicate on the "model" field.
+func ModelContainsFold(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldContainsFold(FieldModel, v))
+}
+
+// BucketDateEQ applies the EQ predicate on the "bucket_date" field.
+func BucketDateEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v))
+}
+
+// BucketDateNEQ applies the NEQ predicate on the "bucket_date" field.
+func BucketDateNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldBucketDate, v))
+}
+
+// BucketDateIn applies the In predicate on the "bucket_date" field.
+func BucketDateIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldBucketDate, vs...))
+}
+
+// BucketDateNotIn applies the NotIn predicate on the "bucket_date" field.
+func BucketDateNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldBucketDate, vs...))
+}
+
+// BucketDateGT applies the GT predicate on the "bucket_date" field.
+func BucketDateGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldBucketDate, v))
+}
+
+// BucketDateGTE applies the GTE predicate on the "bucket_date" field.
+func BucketDateGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldBucketDate, v))
+}
+
+// BucketDateLT applies the LT predicate on the "bucket_date" field.
+func BucketDateLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldBucketDate, v))
+}
+
+// BucketDateLTE applies the LTE predicate on the "bucket_date" field.
+func BucketDateLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldBucketDate, v))
+}
+
+// TotalChecksEQ applies the EQ predicate on the "total_checks" field.
+func TotalChecksEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v))
+}
+
+// TotalChecksNEQ applies the NEQ predicate on the "total_checks" field.
+func TotalChecksNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldTotalChecks, v))
+}
+
+// TotalChecksIn applies the In predicate on the "total_checks" field.
+func TotalChecksIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldTotalChecks, vs...))
+}
+
+// TotalChecksNotIn applies the NotIn predicate on the "total_checks" field.
+func TotalChecksNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldTotalChecks, vs...))
+}
+
+// TotalChecksGT applies the GT predicate on the "total_checks" field.
+func TotalChecksGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldTotalChecks, v))
+}
+
+// TotalChecksGTE applies the GTE predicate on the "total_checks" field.
+func TotalChecksGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldTotalChecks, v))
+}
+
+// TotalChecksLT applies the LT predicate on the "total_checks" field.
+func TotalChecksLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldTotalChecks, v))
+}
+
+// TotalChecksLTE applies the LTE predicate on the "total_checks" field.
+func TotalChecksLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldTotalChecks, v))
+}
+
+// OkCountEQ applies the EQ predicate on the "ok_count" field.
+func OkCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v))
+}
+
+// OkCountNEQ applies the NEQ predicate on the "ok_count" field.
+func OkCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOkCount, v))
+}
+
+// OkCountIn applies the In predicate on the "ok_count" field.
+func OkCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOkCount, vs...))
+}
+
+// OkCountNotIn applies the NotIn predicate on the "ok_count" field.
+func OkCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOkCount, vs...))
+}
+
+// OkCountGT applies the GT predicate on the "ok_count" field.
+func OkCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOkCount, v))
+}
+
+// OkCountGTE applies the GTE predicate on the "ok_count" field.
+func OkCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOkCount, v))
+}
+
+// OkCountLT applies the LT predicate on the "ok_count" field.
+func OkCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOkCount, v))
+}
+
+// OkCountLTE applies the LTE predicate on the "ok_count" field.
+func OkCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOkCount, v))
+}
+
+// OperationalCountEQ applies the EQ predicate on the "operational_count" field.
+func OperationalCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v))
+}
+
+// OperationalCountNEQ applies the NEQ predicate on the "operational_count" field.
+func OperationalCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOperationalCount, v))
+}
+
+// OperationalCountIn applies the In predicate on the "operational_count" field.
+func OperationalCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOperationalCount, vs...))
+}
+
+// OperationalCountNotIn applies the NotIn predicate on the "operational_count" field.
+func OperationalCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOperationalCount, vs...))
+}
+
+// OperationalCountGT applies the GT predicate on the "operational_count" field.
+func OperationalCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOperationalCount, v))
+}
+
+// OperationalCountGTE applies the GTE predicate on the "operational_count" field.
+func OperationalCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOperationalCount, v))
+}
+
+// OperationalCountLT applies the LT predicate on the "operational_count" field.
+func OperationalCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOperationalCount, v))
+}
+
+// OperationalCountLTE applies the LTE predicate on the "operational_count" field.
+func OperationalCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOperationalCount, v))
+}
+
+// DegradedCountEQ applies the EQ predicate on the "degraded_count" field.
+func DegradedCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v))
+}
+
+// DegradedCountNEQ applies the NEQ predicate on the "degraded_count" field.
+func DegradedCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldDegradedCount, v))
+}
+
+// DegradedCountIn applies the In predicate on the "degraded_count" field.
+func DegradedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldDegradedCount, vs...))
+}
+
+// DegradedCountNotIn applies the NotIn predicate on the "degraded_count" field.
+func DegradedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldDegradedCount, vs...))
+}
+
+// DegradedCountGT applies the GT predicate on the "degraded_count" field.
+func DegradedCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldDegradedCount, v))
+}
+
+// DegradedCountGTE applies the GTE predicate on the "degraded_count" field.
+func DegradedCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldDegradedCount, v))
+}
+
+// DegradedCountLT applies the LT predicate on the "degraded_count" field.
+func DegradedCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldDegradedCount, v))
+}
+
+// DegradedCountLTE applies the LTE predicate on the "degraded_count" field.
+func DegradedCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldDegradedCount, v))
+}
+
+// FailedCountEQ applies the EQ predicate on the "failed_count" field.
+func FailedCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v))
+}
+
+// FailedCountNEQ applies the NEQ predicate on the "failed_count" field.
+func FailedCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldFailedCount, v))
+}
+
+// FailedCountIn applies the In predicate on the "failed_count" field.
+func FailedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldFailedCount, vs...))
+}
+
+// FailedCountNotIn applies the NotIn predicate on the "failed_count" field.
+func FailedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldFailedCount, vs...))
+}
+
+// FailedCountGT applies the GT predicate on the "failed_count" field.
+func FailedCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldFailedCount, v))
+}
+
+// FailedCountGTE applies the GTE predicate on the "failed_count" field.
+func FailedCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldFailedCount, v))
+}
+
+// FailedCountLT applies the LT predicate on the "failed_count" field.
+func FailedCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldFailedCount, v))
+}
+
+// FailedCountLTE applies the LTE predicate on the "failed_count" field.
+func FailedCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldFailedCount, v))
+}
+
+// ErrorCountEQ applies the EQ predicate on the "error_count" field.
+func ErrorCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v))
+}
+
+// ErrorCountNEQ applies the NEQ predicate on the "error_count" field.
+func ErrorCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldErrorCount, v))
+}
+
+// ErrorCountIn applies the In predicate on the "error_count" field.
+func ErrorCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldErrorCount, vs...))
+}
+
+// ErrorCountNotIn applies the NotIn predicate on the "error_count" field.
+func ErrorCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldErrorCount, vs...))
+}
+
+// ErrorCountGT applies the GT predicate on the "error_count" field.
+func ErrorCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldErrorCount, v))
+}
+
+// ErrorCountGTE applies the GTE predicate on the "error_count" field.
+func ErrorCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldErrorCount, v))
+}
+
+// ErrorCountLT applies the LT predicate on the "error_count" field.
+func ErrorCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldErrorCount, v))
+}
+
+// ErrorCountLTE applies the LTE predicate on the "error_count" field.
+func ErrorCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldErrorCount, v))
+}
+
+// SumLatencyMsEQ applies the EQ predicate on the "sum_latency_ms" field.
+func SumLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsNEQ applies the NEQ predicate on the "sum_latency_ms" field.
+func SumLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsIn applies the In predicate on the "sum_latency_ms" field.
+func SumLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumLatencyMs, vs...))
+}
+
+// SumLatencyMsNotIn applies the NotIn predicate on the "sum_latency_ms" field.
+func SumLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumLatencyMs, vs...))
+}
+
+// SumLatencyMsGT applies the GT predicate on the "sum_latency_ms" field.
+func SumLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsGTE applies the GTE predicate on the "sum_latency_ms" field.
+func SumLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsLT applies the LT predicate on the "sum_latency_ms" field.
+func SumLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsLTE applies the LTE predicate on the "sum_latency_ms" field.
+func SumLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumLatencyMs, v))
+}
+
+// CountLatencyEQ applies the EQ predicate on the "count_latency" field.
+func CountLatencyEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v))
+}
+
+// CountLatencyNEQ applies the NEQ predicate on the "count_latency" field.
+func CountLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountLatency, v))
+}
+
+// CountLatencyIn applies the In predicate on the "count_latency" field.
+func CountLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountLatency, vs...))
+}
+
+// CountLatencyNotIn applies the NotIn predicate on the "count_latency" field.
+func CountLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountLatency, vs...))
+}
+
+// CountLatencyGT applies the GT predicate on the "count_latency" field.
+func CountLatencyGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountLatency, v))
+}
+
+// CountLatencyGTE applies the GTE predicate on the "count_latency" field.
+func CountLatencyGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountLatency, v))
+}
+
+// CountLatencyLT applies the LT predicate on the "count_latency" field.
+func CountLatencyLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountLatency, v))
+}
+
+// CountLatencyLTE applies the LTE predicate on the "count_latency" field.
+func CountLatencyLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountLatency, v))
+}
+
+// SumPingLatencyMsEQ applies the EQ predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsNEQ applies the NEQ predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsIn applies the In predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumPingLatencyMs, vs...))
+}
+
+// SumPingLatencyMsNotIn applies the NotIn predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumPingLatencyMs, vs...))
+}
+
+// SumPingLatencyMsGT applies the GT predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsGTE applies the GTE predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsLT applies the LT predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsLTE applies the LTE predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumPingLatencyMs, v))
+}
+
+// CountPingLatencyEQ applies the EQ predicate on the "count_ping_latency" field.
+func CountPingLatencyEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyNEQ applies the NEQ predicate on the "count_ping_latency" field.
+func CountPingLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyIn applies the In predicate on the "count_ping_latency" field.
+func CountPingLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountPingLatency, vs...))
+}
+
+// CountPingLatencyNotIn applies the NotIn predicate on the "count_ping_latency" field.
+func CountPingLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountPingLatency, vs...))
+}
+
+// CountPingLatencyGT applies the GT predicate on the "count_ping_latency" field.
+func CountPingLatencyGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyGTE applies the GTE predicate on the "count_ping_latency" field.
+func CountPingLatencyGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyLT applies the LT predicate on the "count_ping_latency" field.
+func CountPingLatencyLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyLTE applies the LTE predicate on the "count_ping_latency" field.
+func CountPingLatencyLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountPingLatency, v))
+}
+
+// ComputedAtEQ applies the EQ predicate on the "computed_at" field.
+func ComputedAtEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
+}
+
+// ComputedAtNEQ applies the NEQ predicate on the "computed_at" field.
+func ComputedAtNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldComputedAt, v))
+}
+
+// ComputedAtIn applies the In predicate on the "computed_at" field.
+func ComputedAtIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldComputedAt, vs...))
+}
+
+// ComputedAtNotIn applies the NotIn predicate on the "computed_at" field.
+func ComputedAtNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldComputedAt, vs...))
+}
+
+// ComputedAtGT applies the GT predicate on the "computed_at" field.
+func ComputedAtGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldComputedAt, v))
+}
+
+// ComputedAtGTE applies the GTE predicate on the "computed_at" field.
+func ComputedAtGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldComputedAt, v))
+}
+
+// ComputedAtLT applies the LT predicate on the "computed_at" field.
+func ComputedAtLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldComputedAt, v))
+}
+
+// ComputedAtLTE applies the LTE predicate on the "computed_at" field.
+func ComputedAtLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldComputedAt, v))
+}
+
+// HasMonitor applies the HasEdge predicate on the "monitor" edge.
+func HasMonitor() predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates).
+func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ step := newMonitorStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitordailyrollup_create.go b/backend/ent/channelmonitordailyrollup_create.go
new file mode 100644
index 00000000..c4850751
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_create.go
@@ -0,0 +1,1593 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+)
+
+// ChannelMonitorDailyRollupCreate is the builder for creating a ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupCreate struct {
+ config
+ mutation *ChannelMonitorDailyRollupMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetDeletedAt sets the "deleted_at" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetDeletedAt(v)
+ return _c
+}
+
+// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetDeletedAt(*v)
+ }
+ return _c
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitorID(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetMonitorID(v)
+ return _c
+}
+
+// SetModel sets the "model" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetModel(v string) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetModel(v)
+ return _c
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetBucketDate(v)
+ return _c
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetTotalChecks(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetTotalChecks(v)
+ return _c
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetTotalChecks(*v)
+ }
+ return _c
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetOkCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetOkCount(v)
+ return _c
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetOkCount(*v)
+ }
+ return _c
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetOperationalCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetOperationalCount(v)
+ return _c
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetOperationalCount(*v)
+ }
+ return _c
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetDegradedCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetDegradedCount(v)
+ return _c
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetDegradedCount(*v)
+ }
+ return _c
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetFailedCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetFailedCount(v)
+ return _c
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetFailedCount(*v)
+ }
+ return _c
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetErrorCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetErrorCount(v)
+ return _c
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetErrorCount(*v)
+ }
+ return _c
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetSumLatencyMs(v)
+ return _c
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetSumLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetCountLatency(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetCountLatency(v)
+ return _c
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetCountLatency(*v)
+ }
+ return _c
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetSumPingLatencyMs(v)
+ return _c
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetSumPingLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetCountPingLatency(v)
+ return _c
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetCountPingLatency(*v)
+ }
+ return _c
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetComputedAt(v)
+ return _c
+}
+
+// SetNillableComputedAt sets the "computed_at" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableComputedAt(v *time.Time) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetComputedAt(*v)
+ }
+ return _c
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupCreate {
+ return _c.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_c *ChannelMonitorDailyRollupCreate) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorDailyRollup in the database.
+func (_c *ChannelMonitorDailyRollupCreate) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ if err := _c.defaults(); err != nil {
+ return nil, err
+ }
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorDailyRollupCreate) SaveX(ctx context.Context) *ChannelMonitorDailyRollup {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorDailyRollupCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorDailyRollupCreate) defaults() error {
+ if _, ok := _c.mutation.TotalChecks(); !ok {
+ v := channelmonitordailyrollup.DefaultTotalChecks
+ _c.mutation.SetTotalChecks(v)
+ }
+ if _, ok := _c.mutation.OkCount(); !ok {
+ v := channelmonitordailyrollup.DefaultOkCount
+ _c.mutation.SetOkCount(v)
+ }
+ if _, ok := _c.mutation.OperationalCount(); !ok {
+ v := channelmonitordailyrollup.DefaultOperationalCount
+ _c.mutation.SetOperationalCount(v)
+ }
+ if _, ok := _c.mutation.DegradedCount(); !ok {
+ v := channelmonitordailyrollup.DefaultDegradedCount
+ _c.mutation.SetDegradedCount(v)
+ }
+ if _, ok := _c.mutation.FailedCount(); !ok {
+ v := channelmonitordailyrollup.DefaultFailedCount
+ _c.mutation.SetFailedCount(v)
+ }
+ if _, ok := _c.mutation.ErrorCount(); !ok {
+ v := channelmonitordailyrollup.DefaultErrorCount
+ _c.mutation.SetErrorCount(v)
+ }
+ if _, ok := _c.mutation.SumLatencyMs(); !ok {
+ v := channelmonitordailyrollup.DefaultSumLatencyMs
+ _c.mutation.SetSumLatencyMs(v)
+ }
+ if _, ok := _c.mutation.CountLatency(); !ok {
+ v := channelmonitordailyrollup.DefaultCountLatency
+ _c.mutation.SetCountLatency(v)
+ }
+ if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
+ v := channelmonitordailyrollup.DefaultSumPingLatencyMs
+ _c.mutation.SetSumPingLatencyMs(v)
+ }
+ if _, ok := _c.mutation.CountPingLatency(); !ok {
+ v := channelmonitordailyrollup.DefaultCountPingLatency
+ _c.mutation.SetCountPingLatency(v)
+ }
+ if _, ok := _c.mutation.ComputedAt(); !ok {
+ if channelmonitordailyrollup.DefaultComputedAt == nil {
+ return fmt.Errorf("ent: uninitialized channelmonitordailyrollup.DefaultComputedAt (forgotten import ent/runtime?)")
+ }
+ v := channelmonitordailyrollup.DefaultComputedAt()
+ _c.mutation.SetComputedAt(v)
+ }
+ return nil
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorDailyRollupCreate) check() error {
+ if _, ok := _c.mutation.MonitorID(); !ok {
+ return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.monitor_id"`)}
+ }
+ if _, ok := _c.mutation.Model(); !ok {
+ return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.model"`)}
+ }
+ if v, ok := _c.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.BucketDate(); !ok {
+ return &ValidationError{Name: "bucket_date", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.bucket_date"`)}
+ }
+ if _, ok := _c.mutation.TotalChecks(); !ok {
+ return &ValidationError{Name: "total_checks", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.total_checks"`)}
+ }
+ if _, ok := _c.mutation.OkCount(); !ok {
+ return &ValidationError{Name: "ok_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.ok_count"`)}
+ }
+ if _, ok := _c.mutation.OperationalCount(); !ok {
+ return &ValidationError{Name: "operational_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.operational_count"`)}
+ }
+ if _, ok := _c.mutation.DegradedCount(); !ok {
+ return &ValidationError{Name: "degraded_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.degraded_count"`)}
+ }
+ if _, ok := _c.mutation.FailedCount(); !ok {
+ return &ValidationError{Name: "failed_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.failed_count"`)}
+ }
+ if _, ok := _c.mutation.ErrorCount(); !ok {
+ return &ValidationError{Name: "error_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.error_count"`)}
+ }
+ if _, ok := _c.mutation.SumLatencyMs(); !ok {
+ return &ValidationError{Name: "sum_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_latency_ms"`)}
+ }
+ if _, ok := _c.mutation.CountLatency(); !ok {
+ return &ValidationError{Name: "count_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_latency"`)}
+ }
+ if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
+ return &ValidationError{Name: "sum_ping_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_ping_latency_ms"`)}
+ }
+ if _, ok := _c.mutation.CountPingLatency(); !ok {
+ return &ValidationError{Name: "count_ping_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_ping_latency"`)}
+ }
+ if _, ok := _c.mutation.ComputedAt(); !ok {
+ return &ValidationError{Name: "computed_at", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.computed_at"`)}
+ }
+ if len(_c.mutation.MonitorIDs()) == 0 {
+ return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorDailyRollup.monitor"`)}
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorDailyRollupCreate) sqlSave(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRollup, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorDailyRollup{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.DeletedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime, value)
+ _node.DeletedAt = &value
+ }
+ if value, ok := _c.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ _node.Model = value
+ }
+ if value, ok := _c.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ _node.BucketDate = value
+ }
+ if value, ok := _c.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ _node.TotalChecks = value
+ }
+ if value, ok := _c.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ _node.OkCount = value
+ }
+ if value, ok := _c.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ _node.OperationalCount = value
+ }
+ if value, ok := _c.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ _node.DegradedCount = value
+ }
+ if value, ok := _c.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ _node.FailedCount = value
+ }
+ if value, ok := _c.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ _node.ErrorCount = value
+ }
+ if value, ok := _c.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ _node.SumLatencyMs = value
+ }
+ if value, ok := _c.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ _node.CountLatency = value
+ }
+ if value, ok := _c.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ _node.SumPingLatencyMs = value
+ }
+ if value, ok := _c.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ _node.CountPingLatency = value
+ }
+ if value, ok := _c.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ _node.ComputedAt = value
+ }
+ if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.MonitorID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.ChannelMonitorDailyRollup.Create().
+//		SetDeletedAt(v).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
+//			u.SetDeletedAt(v)
+//		}).
+//		Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertOne {
+	_c.conflict = opts
+	return &ChannelMonitorDailyRollupUpsertOne{
+		create: _c,
+	}
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreate) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorDailyRollupUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorDailyRollupUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorDailyRollup node.
+ ChannelMonitorDailyRollupUpsertOne struct {
+ create *ChannelMonitorDailyRollupCreate
+ }
+
+ // ChannelMonitorDailyRollupUpsert is the "OnConflict" setter.
+ ChannelMonitorDailyRollupUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetDeletedAt sets the "deleted_at" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldDeletedAt, v)
+ return u
+}
+
+// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateDeletedAt() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldDeletedAt)
+ return u
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (u *ChannelMonitorDailyRollupUpsert) ClearDeletedAt() *ChannelMonitorDailyRollupUpsert {
+ u.SetNull(channelmonitordailyrollup.FieldDeletedAt)
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldMonitorID, v)
+ return u
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateMonitorID() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldMonitorID)
+ return u
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetModel(v string) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldModel, v)
+ return u
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateModel() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldModel)
+ return u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldBucketDate, v)
+ return u
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateBucketDate() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldBucketDate)
+ return u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldTotalChecks, v)
+ return u
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldTotalChecks)
+ return u
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldTotalChecks, v)
+ return u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetOkCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldOkCount, v)
+ return u
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateOkCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldOkCount)
+ return u
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddOkCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldOkCount, v)
+ return u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldOperationalCount, v)
+ return u
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldOperationalCount)
+ return u
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldOperationalCount, v)
+ return u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldDegradedCount, v)
+ return u
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldDegradedCount)
+ return u
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldDegradedCount, v)
+ return u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldFailedCount, v)
+ return u
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateFailedCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldFailedCount)
+ return u
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldFailedCount, v)
+ return u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldErrorCount, v)
+ return u
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateErrorCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldErrorCount)
+ return u
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldErrorCount, v)
+ return u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldSumLatencyMs, v)
+ return u
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldSumLatencyMs)
+ return u
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldSumLatencyMs, v)
+ return u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldCountLatency, v)
+ return u
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateCountLatency() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldCountLatency)
+ return u
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldCountLatency, v)
+ return u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
+ return u
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldSumPingLatencyMs)
+ return u
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
+ return u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldCountPingLatency, v)
+ return u
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldCountPingLatency)
+ return u
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldCountPingLatency, v)
+ return u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldComputedAt, v)
+ return u
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateComputedAt() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldComputedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateNewValues() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertOne) Ignore() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorDailyRollupUpsertOne) DoNothing() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorDailyRollupUpsertOne) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetDeletedAt sets the "deleted_at" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDeletedAt(v)
+ })
+}
+
+// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateDeletedAt() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDeletedAt()
+ })
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) ClearDeletedAt() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.ClearDeletedAt()
+ })
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetModel(v string) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateModel() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetBucketDate(v)
+ })
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateBucketDate()
+ })
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetTotalChecks(v)
+ })
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddTotalChecks(v)
+ })
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateTotalChecks()
+ })
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOkCount(v)
+ })
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOkCount(v)
+ })
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOkCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOkCount()
+ })
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOperationalCount(v)
+ })
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOperationalCount(v)
+ })
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOperationalCount()
+ })
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDegradedCount(v)
+ })
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddDegradedCount(v)
+ })
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDegradedCount()
+ })
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetFailedCount(v)
+ })
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddFailedCount(v)
+ })
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateFailedCount()
+ })
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetErrorCount(v)
+ })
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddErrorCount(v)
+ })
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateErrorCount()
+ })
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumLatencyMs(v)
+ })
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumLatencyMs(v)
+ })
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumLatencyMs()
+ })
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountLatency(v)
+ })
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountLatency(v)
+ })
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountLatency()
+ })
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumPingLatencyMs(v)
+ })
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumPingLatencyMs(v)
+ })
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumPingLatencyMs()
+ })
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountPingLatency(v)
+ })
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountPingLatency(v)
+ })
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountPingLatency()
+ })
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetComputedAt(v)
+ })
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateComputedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorDailyRollupUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorDailyRollupCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorDailyRollupUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorDailyRollupCreateBulk is the builder for creating many ChannelMonitorDailyRollup entities in bulk.
+type ChannelMonitorDailyRollupCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorDailyRollupCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorDailyRollup entities in the database.
+func (_c *ChannelMonitorDailyRollupCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorDailyRollup, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorDailyRollupMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorDailyRollup {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorDailyRollupCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorDailyRollup.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
+// SetDeletedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorDailyRollupUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorDailyRollupUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorDailyRollupUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorDailyRollup nodes.
+type ChannelMonitorDailyRollupUpsertBulk struct {
+ create *ChannelMonitorDailyRollupCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateNewValues() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertBulk) Ignore() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorDailyRollupUpsertBulk) DoNothing() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorDailyRollupUpsertBulk) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetDeletedAt sets the "deleted_at" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDeletedAt(v)
+ })
+}
+
+// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateDeletedAt() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDeletedAt()
+ })
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) ClearDeletedAt() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.ClearDeletedAt()
+ })
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetModel(v string) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateModel() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetBucketDate(v)
+ })
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateBucketDate()
+ })
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetTotalChecks(v)
+ })
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddTotalChecks(v)
+ })
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateTotalChecks()
+ })
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOkCount(v)
+ })
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOkCount(v)
+ })
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOkCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOkCount()
+ })
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOperationalCount(v)
+ })
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOperationalCount(v)
+ })
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOperationalCount()
+ })
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDegradedCount(v)
+ })
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddDegradedCount(v)
+ })
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDegradedCount()
+ })
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetFailedCount(v)
+ })
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddFailedCount(v)
+ })
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateFailedCount()
+ })
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetErrorCount(v)
+ })
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddErrorCount(v)
+ })
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateErrorCount()
+ })
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumLatencyMs(v)
+ })
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumLatencyMs(v)
+ })
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumLatencyMs()
+ })
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountLatency(v)
+ })
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountLatency(v)
+ })
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountLatency()
+ })
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumPingLatencyMs(v)
+ })
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumPingLatencyMs(v)
+ })
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumPingLatencyMs()
+ })
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountPingLatency(v)
+ })
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountPingLatency(v)
+ })
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountPingLatency()
+ })
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetComputedAt(v)
+ })
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateComputedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorDailyRollupUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorDailyRollupCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorDailyRollupCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitordailyrollup_delete.go b/backend/ent/channelmonitordailyrollup_delete.go
new file mode 100644
index 00000000..460c94f8
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupDelete is the builder for deleting a ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder.
+func (_d *ChannelMonitorDailyRollupDelete) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorDailyRollupDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDailyRollupDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorDailyRollupDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorDailyRollupDeleteOne is the builder for deleting a single ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupDeleteOne struct {
+ _d *ChannelMonitorDailyRollupDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder.
+func (_d *ChannelMonitorDailyRollupDeleteOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorDailyRollupDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDailyRollupDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitordailyrollup_query.go b/backend/ent/channelmonitordailyrollup_query.go
new file mode 100644
index 00000000..30528575
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupQuery is the builder for querying ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitordailyrollup.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorDailyRollup
+ withMonitor *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorDailyRollupQuery builder.
+func (_q *ChannelMonitorDailyRollupQuery) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorDailyRollupQuery) Limit(limit int) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorDailyRollupQuery) Offset(offset int) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorDailyRollupQuery) Unique(unique bool) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorDailyRollupQuery) Order(o ...channelmonitordailyrollup.OrderOption) *ChannelMonitorDailyRollupQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitor chains the current query on the "monitor" edge.
+func (_q *ChannelMonitorDailyRollupQuery) QueryMonitor() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorDailyRollup entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup was found.
+func (_q *ChannelMonitorDailyRollupQuery) First(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitordailyrollup.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) FirstX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorDailyRollup ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup ID was found.
+func (_q *ChannelMonitorDailyRollupQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorDailyRollup entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup entity is found.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup entities are found.
+func (_q *ChannelMonitorDailyRollupQuery) Only(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ return nil, &NotSingularError{channelmonitordailyrollup.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorDailyRollup ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ err = &NotSingularError{channelmonitordailyrollup.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorDailyRollups.
+func (_q *ChannelMonitorDailyRollupQuery) All(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorDailyRollup, *ChannelMonitorDailyRollupQuery]()
+ return withInterceptors[[]*ChannelMonitorDailyRollup](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) AllX(ctx context.Context) []*ChannelMonitorDailyRollup {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorDailyRollup IDs.
+func (_q *ChannelMonitorDailyRollupQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitordailyrollup.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorDailyRollupQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorDailyRollupQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorDailyRollupQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorDailyRollupQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorDailyRollupQuery) Clone() *ChannelMonitorDailyRollupQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorDailyRollupQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitordailyrollup.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorDailyRollup{}, _q.predicates...),
+ withMonitor: _q.withMonitor.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitor tells the query-builder to eager-load the nodes that are connected to
+// the "monitor" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorDailyRollupQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitor = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// DeletedAt time.Time `json:"deleted_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorDailyRollup.Query().
+// GroupBy(channelmonitordailyrollup.FieldDeletedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorDailyRollupQuery) GroupBy(field string, fields ...string) *ChannelMonitorDailyRollupGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorDailyRollupGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitordailyrollup.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// DeletedAt time.Time `json:"deleted_at,omitempty"`
+// }
+//
+// client.ChannelMonitorDailyRollup.Query().
+// Select(channelmonitordailyrollup.FieldDeletedAt).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorDailyRollupQuery) Select(fields ...string) *ChannelMonitorDailyRollupSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorDailyRollupSelect{ChannelMonitorDailyRollupQuery: _q}
+ sbuild.label = channelmonitordailyrollup.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorDailyRollupSelect configured with the given aggregations.
+func (_q *ChannelMonitorDailyRollupQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitordailyrollup.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorDailyRollup, error) {
+ var (
+ nodes = []*ChannelMonitorDailyRollup{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitor != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorDailyRollup).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorDailyRollup{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitor; query != nil {
+ if err := _q.loadMonitor(ctx, query, nodes, nil,
+ func(n *ChannelMonitorDailyRollup, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorDailyRollup, init func(*ChannelMonitorDailyRollup), assign func(*ChannelMonitorDailyRollup, *ChannelMonitor)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitorDailyRollup)
+ for i := range nodes {
+ fk := nodes[i].MonitorID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitor.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitordailyrollup.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withMonitor != nil {
+ _spec.Node.AddColumnOnce(channelmonitordailyrollup.FieldMonitorID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitordailyrollup.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitordailyrollup.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorDailyRollupQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorDailyRollupQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorDailyRollupGroupBy is the group-by builder for ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupGroupBy struct {
+ selector
+ build *ChannelMonitorDailyRollupQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorDailyRollupGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorDailyRollupGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorDailyRollupGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorDailyRollupSelect is the builder for selecting fields of ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupSelect struct {
+ *ChannelMonitorDailyRollupQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorDailyRollupSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorDailyRollupSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupSelect](ctx, _s.ChannelMonitorDailyRollupQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorDailyRollupSelect) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitordailyrollup_update.go b/backend/ent/channelmonitordailyrollup_update.go
new file mode 100644
index 00000000..0b82f8bf
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_update.go
@@ -0,0 +1,1025 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupUpdate is the builder for updating ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder.
+func (_u *ChannelMonitorDailyRollupUpdate) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetDeletedAt sets the "deleted_at" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetDeletedAt(v)
+ return _u
+}
+
+// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetDeletedAt(*v)
+ }
+ return _u
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (_u *ChannelMonitorDailyRollupUpdate) ClearDeletedAt() *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ClearDeletedAt()
+ return _u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetModel(v string) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetBucketDate(v)
+ return _u
+}
+
+// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetBucketDate(*v)
+ }
+ return _u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetTotalChecks()
+ _u.mutation.SetTotalChecks(v)
+ return _u
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetTotalChecks(*v)
+ }
+ return _u
+}
+
+// AddTotalChecks adds value to the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddTotalChecks(v)
+ return _u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetOkCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetOkCount()
+ _u.mutation.SetOkCount(v)
+ return _u
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetOkCount(*v)
+ }
+ return _u
+}
+
+// AddOkCount adds value to the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddOkCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddOkCount(v)
+ return _u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetOperationalCount()
+ _u.mutation.SetOperationalCount(v)
+ return _u
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetOperationalCount(*v)
+ }
+ return _u
+}
+
+// AddOperationalCount adds value to the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddOperationalCount(v)
+ return _u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetDegradedCount()
+ _u.mutation.SetDegradedCount(v)
+ return _u
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetDegradedCount(*v)
+ }
+ return _u
+}
+
+// AddDegradedCount adds value to the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddDegradedCount(v)
+ return _u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetFailedCount()
+ _u.mutation.SetFailedCount(v)
+ return _u
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetFailedCount(*v)
+ }
+ return _u
+}
+
+// AddFailedCount adds value to the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddFailedCount(v)
+ return _u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetErrorCount()
+ _u.mutation.SetErrorCount(v)
+ return _u
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetErrorCount(*v)
+ }
+ return _u
+}
+
+// AddErrorCount adds value to the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddErrorCount(v)
+ return _u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetSumLatencyMs()
+ _u.mutation.SetSumLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetSumLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumLatencyMs adds value to the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddSumLatencyMs(v)
+ return _u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetCountLatency()
+ _u.mutation.SetCountLatency(v)
+ return _u
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetCountLatency(*v)
+ }
+ return _u
+}
+
+// AddCountLatency adds value to the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddCountLatency(v)
+ return _u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+// The Reset call first discards any previously queued Add on the same
+// field, so an absolute Set always overrides a pending relative Add.
+func (_u *ChannelMonitorDailyRollupUpdate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetSumPingLatencyMs()
+ _u.mutation.SetSumPingLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetSumPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddSumPingLatencyMs(v)
+ return _u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetCountPingLatency()
+ _u.mutation.SetCountPingLatency(v)
+ return _u
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetCountPingLatency(*v)
+ }
+ return _u
+}
+
+// AddCountPingLatency adds value to the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddCountPingLatency(v)
+ return _u
+}
+
+// SetComputedAt sets the "computed_at" field.
+// When left unset, defaults() populates it from UpdateDefaultComputedAt
+// just before save.
+func (_u *ChannelMonitorDailyRollupUpdate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetComputedAt(v)
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+// Convenience wrapper around SetMonitorID; v must be non-nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdate {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_u *ChannelMonitorDailyRollupUpdate) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+// Note: "monitor" is a required edge; check() rejects a mutation that
+// clears it while new monitor IDs are also queued.
+func (_u *ChannelMonitorDailyRollupUpdate) ClearMonitor() *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+// defaults() runs first so computed_at gets its update-time default; the
+// actual write is delegated to sqlSave through any registered hooks.
+func (_u *ChannelMonitorDailyRollupUpdate) Save(ctx context.Context) (int, error) {
+ if err := _u.defaults(); err != nil {
+ return 0, err
+ }
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorDailyRollupUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+// It errors if UpdateDefaultComputedAt is still nil, which happens when
+// the ent/runtime package (responsible for initializing it) was never
+// imported by the application.
+func (_u *ChannelMonitorDailyRollupUpdate) defaults() error {
+ if _, ok := _u.mutation.ComputedAt(); !ok {
+ if channelmonitordailyrollup.UpdateDefaultComputedAt == nil {
+ return fmt.Errorf("ent: uninitialized channelmonitordailyrollup.UpdateDefaultComputedAt (forgotten import ent/runtime?)")
+ }
+ v := channelmonitordailyrollup.UpdateDefaultComputedAt()
+ _u.mutation.SetComputedAt(v)
+ }
+ return nil
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorDailyRollupUpdate) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ // "monitor" is a required unique edge: it may be cleared or re-pointed,
+ // but queuing both in the same mutation is contradictory and rejected.
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`)
+ }
+ return nil
+}
+
+// sqlSave validates the queued mutation, lowers it into a sqlgraph update
+// spec, and executes it, returning the number of rows affected.
+func (_u *ChannelMonitorDailyRollupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ // Apply any Where predicates to the UPDATE statement.
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ // Absolute values queued via Set* are written with SetField; relative
+ // deltas queued via Add* are written with AddField.
+ if value, ok := _u.mutation.DeletedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime, value)
+ }
+ if _u.mutation.DeletedAtCleared() {
+ _spec.ClearField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedTotalChecks(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOkCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOperationalCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedDegradedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedFailedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedErrorCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountPingLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ }
+ // Edge bookkeeping for the M2O "monitor" edge: clear the old FK first,
+ // then (if requested) re-point it at the newly queued monitor ID.
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ // Execute, translating sqlgraph errors into ent's typed errors.
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ // Mark the mutation as consumed so it cannot be replayed by hooks.
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorDailyRollupUpdateOne is the builder for updating a single ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupUpdateOne struct {
+ config
+ // fields optionally restricts the columns scanned back into the
+ // returned entity; populated via Select.
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// --- ChannelMonitorDailyRollupUpdateOne builder methods (generated) ---
+
+// SetDeletedAt sets the "deleted_at" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetDeletedAt(v)
+ return _u
+}
+
+// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableDeletedAt(v *time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetDeletedAt(*v)
+ }
+ return _u
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ClearDeletedAt() *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ClearDeletedAt()
+ return _u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetModel(v string) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetBucketDate(v)
+ return _u
+}
+
+// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetBucketDate(*v)
+ }
+ return _u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+// For numeric fields, the Reset call discards any previously queued Add
+// so an absolute Set always overrides a pending relative Add.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetTotalChecks()
+ _u.mutation.SetTotalChecks(v)
+ return _u
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetTotalChecks(*v)
+ }
+ return _u
+}
+
+// AddTotalChecks adds value to the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddTotalChecks(v)
+ return _u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetOkCount()
+ _u.mutation.SetOkCount(v)
+ return _u
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetOkCount(*v)
+ }
+ return _u
+}
+
+// AddOkCount adds value to the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddOkCount(v)
+ return _u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetOperationalCount()
+ _u.mutation.SetOperationalCount(v)
+ return _u
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetOperationalCount(*v)
+ }
+ return _u
+}
+
+// AddOperationalCount adds value to the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddOperationalCount(v)
+ return _u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetDegradedCount()
+ _u.mutation.SetDegradedCount(v)
+ return _u
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetDegradedCount(*v)
+ }
+ return _u
+}
+
+// AddDegradedCount adds value to the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddDegradedCount(v)
+ return _u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetFailedCount()
+ _u.mutation.SetFailedCount(v)
+ return _u
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetFailedCount(*v)
+ }
+ return _u
+}
+
+// AddFailedCount adds value to the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddFailedCount(v)
+ return _u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetErrorCount()
+ _u.mutation.SetErrorCount(v)
+ return _u
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetErrorCount(*v)
+ }
+ return _u
+}
+
+// AddErrorCount adds value to the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddErrorCount(v)
+ return _u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetSumLatencyMs()
+ _u.mutation.SetSumLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetSumLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumLatencyMs adds value to the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddSumLatencyMs(v)
+ return _u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetCountLatency()
+ _u.mutation.SetCountLatency(v)
+ return _u
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetCountLatency(*v)
+ }
+ return _u
+}
+
+// AddCountLatency adds value to the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddCountLatency(v)
+ return _u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetSumPingLatencyMs()
+ _u.mutation.SetSumPingLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetSumPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddSumPingLatencyMs(v)
+ return _u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetCountPingLatency()
+ _u.mutation.SetCountPingLatency(v)
+ return _u
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetCountPingLatency(*v)
+ }
+ return _u
+}
+
+// AddCountPingLatency adds value to the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddCountPingLatency(v)
+ return _u
+}
+
+// SetComputedAt sets the "computed_at" field.
+// When left unset, defaults() populates it from UpdateDefaultComputedAt
+// just before save.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetComputedAt(v)
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+// Convenience wrapper around SetMonitorID; v must be non-nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdateOne {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ClearMonitor() *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+// Only the returned entity is narrowed; the update itself still writes
+// every mutated field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Select(field string, fields ...string) *ChannelMonitorDailyRollupUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorDailyRollup entity.
+// defaults() runs first so computed_at gets its update-time default; the
+// actual write is delegated to sqlSave through any registered hooks.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ if err := _u.defaults(); err != nil {
+ return nil, err
+ }
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SaveX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+// It errors if UpdateDefaultComputedAt is still nil, which happens when
+// the ent/runtime package (responsible for initializing it) was never
+// imported by the application.
+func (_u *ChannelMonitorDailyRollupUpdateOne) defaults() error {
+ if _, ok := _u.mutation.ComputedAt(); !ok {
+ if channelmonitordailyrollup.UpdateDefaultComputedAt == nil {
+ return fmt.Errorf("ent: uninitialized channelmonitordailyrollup.UpdateDefaultComputedAt (forgotten import ent/runtime?)")
+ }
+ v := channelmonitordailyrollup.UpdateDefaultComputedAt()
+ _u.mutation.SetComputedAt(v)
+ }
+ return nil
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ // "monitor" is a required unique edge: it may be cleared or re-pointed,
+ // but queuing both in the same mutation is contradictory and rejected.
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`)
+ }
+ return nil
+}
+
+// sqlSave validates the queued mutation, lowers it into a sqlgraph update
+// spec for the single entity identified by the mutation's ID, executes it,
+// and scans the updated row back into the returned entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorDailyRollup, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ // UpdateOne requires the target entity's ID on the mutation.
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorDailyRollup.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ // Honor Select(): scan back only the chosen columns, with the ID
+ // column always included.
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID)
+ for _, f := range fields {
+ if !channelmonitordailyrollup.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitordailyrollup.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ // Apply any Where predicates to the UPDATE statement.
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ // Absolute values queued via Set* become SetField; relative deltas
+ // queued via Add* become AddField.
+ if value, ok := _u.mutation.DeletedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime, value)
+ }
+ if _u.mutation.DeletedAtCleared() {
+ _spec.ClearField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedTotalChecks(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOkCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOperationalCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedDegradedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedFailedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedErrorCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountPingLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ }
+ // Edge bookkeeping for the M2O "monitor" edge: clear the old FK first,
+ // then (if requested) re-point it at the newly queued monitor ID.
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ // Wire result scanning into the entity that will be returned.
+ _node = &ChannelMonitorDailyRollup{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ // Execute, translating sqlgraph errors into ent's typed errors.
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ // Mark the mutation as consumed so it cannot be replayed by hooks.
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitorhistory.go b/backend/ent/channelmonitorhistory.go
index 70dde542..256eaf5f 100644
--- a/backend/ent/channelmonitorhistory.go
+++ b/backend/ent/channelmonitorhistory.go
@@ -18,6 +18,8 @@ type ChannelMonitorHistory struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
+ // DeletedAt holds the value of the "deleted_at" field.
+ DeletedAt *time.Time `json:"deleted_at,omitempty"`
// MonitorID holds the value of the "monitor_id" field.
MonitorID int64 `json:"monitor_id,omitempty"`
// Model holds the value of the "model" field.
@@ -67,7 +69,7 @@ func (*ChannelMonitorHistory) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case channelmonitorhistory.FieldModel, channelmonitorhistory.FieldStatus, channelmonitorhistory.FieldMessage:
values[i] = new(sql.NullString)
- case channelmonitorhistory.FieldCheckedAt:
+ case channelmonitorhistory.FieldDeletedAt, channelmonitorhistory.FieldCheckedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -90,6 +92,13 @@ func (_m *ChannelMonitorHistory) assignValues(columns []string, values []any) er
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
+ case channelmonitorhistory.FieldDeletedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
+ } else if value.Valid {
+ _m.DeletedAt = new(time.Time)
+ *_m.DeletedAt = value.Time
+ }
case channelmonitorhistory.FieldMonitorID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
@@ -175,6 +184,11 @@ func (_m *ChannelMonitorHistory) String() string {
var builder strings.Builder
builder.WriteString("ChannelMonitorHistory(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ if v := _m.DeletedAt; v != nil {
+ builder.WriteString("deleted_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
builder.WriteString("monitor_id=")
builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
builder.WriteString(", ")
diff --git a/backend/ent/channelmonitorhistory/channelmonitorhistory.go b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
index 6a9dc006..da59791b 100644
--- a/backend/ent/channelmonitorhistory/channelmonitorhistory.go
+++ b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
@@ -6,6 +6,7 @@ import (
"fmt"
"time"
+ "entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
@@ -15,6 +16,8 @@ const (
Label = "channel_monitor_history"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
+ // FieldDeletedAt holds the string denoting the deleted_at field in the database.
+ FieldDeletedAt = "deleted_at"
// FieldMonitorID holds the string denoting the monitor_id field in the database.
FieldMonitorID = "monitor_id"
// FieldModel holds the string denoting the model field in the database.
@@ -45,6 +48,7 @@ const (
// Columns holds all SQL columns for channelmonitorhistory fields.
var Columns = []string{
FieldID,
+ FieldDeletedAt,
FieldMonitorID,
FieldModel,
FieldStatus,
@@ -64,7 +68,14 @@ func ValidColumn(column string) bool {
return false
}
+// Note that the variables below are initialized by the runtime
+// package on the initialization of the application. Therefore,
+// it should be imported in the main as follows:
+//
+// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
+ Hooks [1]ent.Hook
+ Interceptors [1]ent.Interceptor
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// DefaultMessage holds the default value on creation for the "message" field.
@@ -108,6 +119,11 @@ func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
+// ByDeletedAt orders the results by the deleted_at field.
+// NOTE(review): deleted_at is nullable; ordering of NULL rows follows the
+// database's NULL-ordering rules — confirm against the target dialect.
+func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
+}
+
// ByMonitorID orders the results by the monitor_id field.
func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
diff --git a/backend/ent/channelmonitorhistory/where.go b/backend/ent/channelmonitorhistory/where.go
index afa73f35..7b1cd50d 100644
--- a/backend/ent/channelmonitorhistory/where.go
+++ b/backend/ent/channelmonitorhistory/where.go
@@ -55,6 +55,11 @@ func IDLTE(id int64) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldID, id))
}
+// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
+// NOTE: SQL equality never matches NULL; use DeletedAtIsNil to select
+// rows whose deleted_at is unset.
+func DeletedAt(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldDeletedAt, v))
+}
+
// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
func MonitorID(v int64) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
@@ -85,6 +90,56 @@ func CheckedAt(v time.Time) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v))
}
+// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
+func DeletedAtEQ(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldDeletedAt, v))
+}
+
+// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
+func DeletedAtNEQ(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldDeletedAt, v))
+}
+
+// DeletedAtIn applies the In predicate on the "deleted_at" field.
+func DeletedAtIn(vs ...time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldDeletedAt, vs...))
+}
+
+// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
+func DeletedAtNotIn(vs ...time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldDeletedAt, vs...))
+}
+
+// DeletedAtGT applies the GT predicate on the "deleted_at" field.
+func DeletedAtGT(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldDeletedAt, v))
+}
+
+// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
+func DeletedAtGTE(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldDeletedAt, v))
+}
+
+// DeletedAtLT applies the LT predicate on the "deleted_at" field.
+func DeletedAtLT(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldDeletedAt, v))
+}
+
+// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
+func DeletedAtLTE(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldDeletedAt, v))
+}
+
+// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
+func DeletedAtIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldDeletedAt))
+}
+
+// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
+func DeletedAtNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldDeletedAt))
+}
+
// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
func MonitorIDEQ(v int64) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
diff --git a/backend/ent/channelmonitorhistory_create.go b/backend/ent/channelmonitorhistory_create.go
index 71034865..9a68c9ce 100644
--- a/backend/ent/channelmonitorhistory_create.go
+++ b/backend/ent/channelmonitorhistory_create.go
@@ -23,6 +23,20 @@ type ChannelMonitorHistoryCreate struct {
conflict []sql.ConflictOption
}
+// SetDeletedAt sets the "deleted_at" field.
+func (_c *ChannelMonitorHistoryCreate) SetDeletedAt(v time.Time) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetDeletedAt(v)
+ return _c
+}
+
+// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetDeletedAt(*v)
+ }
+ return _c
+}
+
// SetMonitorID sets the "monitor_id" field.
func (_c *ChannelMonitorHistoryCreate) SetMonitorID(v int64) *ChannelMonitorHistoryCreate {
_c.mutation.SetMonitorID(v)
@@ -109,7 +123,9 @@ func (_c *ChannelMonitorHistoryCreate) Mutation() *ChannelMonitorHistoryMutation
// Save creates the ChannelMonitorHistory in the database.
func (_c *ChannelMonitorHistoryCreate) Save(ctx context.Context) (*ChannelMonitorHistory, error) {
- _c.defaults()
+ if err := _c.defaults(); err != nil {
+ return nil, err
+ }
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
@@ -136,15 +152,19 @@ func (_c *ChannelMonitorHistoryCreate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
-func (_c *ChannelMonitorHistoryCreate) defaults() {
+func (_c *ChannelMonitorHistoryCreate) defaults() error {
if _, ok := _c.mutation.Message(); !ok {
v := channelmonitorhistory.DefaultMessage
_c.mutation.SetMessage(v)
}
if _, ok := _c.mutation.CheckedAt(); !ok {
+ if channelmonitorhistory.DefaultCheckedAt == nil {
+ return fmt.Errorf("ent: uninitialized channelmonitorhistory.DefaultCheckedAt (forgotten import ent/runtime?)")
+ }
v := channelmonitorhistory.DefaultCheckedAt()
_c.mutation.SetCheckedAt(v)
}
+ return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -206,6 +226,10 @@ func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sq
_spec = sqlgraph.NewCreateSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.DeletedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldDeletedAt, field.TypeTime, value)
+ _node.DeletedAt = &value
+ }
if value, ok := _c.mutation.Model(); ok {
_spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
_node.Model = value
@@ -254,7 +278,7 @@ func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sq
// of the `INSERT` statement. For example:
//
// client.ChannelMonitorHistory.Create().
-// SetMonitorID(v).
+// SetDeletedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
@@ -263,7 +287,7 @@ func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sq
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
-// SetMonitorID(v+v).
+// SetDeletedAt(v+v).
// }).
// Exec(ctx)
func (_c *ChannelMonitorHistoryCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertOne {
@@ -299,6 +323,24 @@ type (
}
)
+// SetDeletedAt sets the "deleted_at" field.
+func (u *ChannelMonitorHistoryUpsert) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldDeletedAt, v)
+ return u
+}
+
+// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateDeletedAt() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldDeletedAt)
+ return u
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (u *ChannelMonitorHistoryUpsert) ClearDeletedAt() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldDeletedAt)
+ return u
+}
+
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorHistoryUpsert) SetMonitorID(v int64) *ChannelMonitorHistoryUpsert {
u.Set(channelmonitorhistory.FieldMonitorID, v)
@@ -453,6 +495,27 @@ func (u *ChannelMonitorHistoryUpsertOne) Update(set func(*ChannelMonitorHistoryU
return u
}
+// SetDeletedAt sets the "deleted_at" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetDeletedAt(v)
+ })
+}
+
+// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateDeletedAt() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateDeletedAt()
+ })
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearDeletedAt() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearDeletedAt()
+ })
+}
+
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorHistoryUpsertOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertOne {
return u.Update(func(s *ChannelMonitorHistoryUpsert) {
@@ -721,7 +784,7 @@ func (_c *ChannelMonitorHistoryCreateBulk) ExecX(ctx context.Context) {
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
-// SetMonitorID(v+v).
+// SetDeletedAt(v+v).
// }).
// Exec(ctx)
func (_c *ChannelMonitorHistoryCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertBulk {
@@ -790,6 +853,27 @@ func (u *ChannelMonitorHistoryUpsertBulk) Update(set func(*ChannelMonitorHistory
return u
}
+// SetDeletedAt sets the "deleted_at" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetDeletedAt(v)
+ })
+}
+
+// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateDeletedAt() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateDeletedAt()
+ })
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearDeletedAt() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearDeletedAt()
+ })
+}
+
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorHistoryUpsertBulk) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertBulk {
return u.Update(func(s *ChannelMonitorHistoryUpsert) {
diff --git a/backend/ent/channelmonitorhistory_query.go b/backend/ent/channelmonitorhistory_query.go
index 1fb872ad..26a1528f 100644
--- a/backend/ent/channelmonitorhistory_query.go
+++ b/backend/ent/channelmonitorhistory_query.go
@@ -300,12 +300,12 @@ func (_q *ChannelMonitorHistoryQuery) WithMonitor(opts ...func(*ChannelMonitorQu
// Example:
//
// var v []struct {
-// MonitorID int64 `json:"monitor_id,omitempty"`
+// DeletedAt time.Time `json:"deleted_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.ChannelMonitorHistory.Query().
-// GroupBy(channelmonitorhistory.FieldMonitorID).
+// GroupBy(channelmonitorhistory.FieldDeletedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *ChannelMonitorHistoryGroupBy {
@@ -323,11 +323,11 @@ func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *C
// Example:
//
// var v []struct {
-// MonitorID int64 `json:"monitor_id,omitempty"`
+// DeletedAt time.Time `json:"deleted_at,omitempty"`
// }
//
// client.ChannelMonitorHistory.Query().
-// Select(channelmonitorhistory.FieldMonitorID).
+// Select(channelmonitorhistory.FieldDeletedAt).
// Scan(ctx, &v)
func (_q *ChannelMonitorHistoryQuery) Select(fields ...string) *ChannelMonitorHistorySelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
diff --git a/backend/ent/channelmonitorhistory_update.go b/backend/ent/channelmonitorhistory_update.go
index a85a8072..85193ec1 100644
--- a/backend/ent/channelmonitorhistory_update.go
+++ b/backend/ent/channelmonitorhistory_update.go
@@ -29,6 +29,26 @@ func (_u *ChannelMonitorHistoryUpdate) Where(ps ...predicate.ChannelMonitorHisto
return _u
}
+// SetDeletedAt sets the "deleted_at" field.
+func (_u *ChannelMonitorHistoryUpdate) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetDeletedAt(v)
+ return _u
+}
+
+// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetDeletedAt(*v)
+ }
+ return _u
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearDeletedAt() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearDeletedAt()
+ return _u
+}
+
// SetMonitorID sets the "monitor_id" field.
func (_u *ChannelMonitorHistoryUpdate) SetMonitorID(v int64) *ChannelMonitorHistoryUpdate {
_u.mutation.SetMonitorID(v)
@@ -237,6 +257,12 @@ func (_u *ChannelMonitorHistoryUpdate) sqlSave(ctx context.Context) (_node int,
}
}
}
+ if value, ok := _u.mutation.DeletedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldDeletedAt, field.TypeTime, value)
+ }
+ if _u.mutation.DeletedAtCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldDeletedAt, field.TypeTime)
+ }
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
}
@@ -319,6 +345,26 @@ type ChannelMonitorHistoryUpdateOne struct {
mutation *ChannelMonitorHistoryMutation
}
+// SetDeletedAt sets the "deleted_at" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetDeletedAt(v)
+ return _u
+}
+
+// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableDeletedAt(v *time.Time) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetDeletedAt(*v)
+ }
+ return _u
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearDeletedAt() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearDeletedAt()
+ return _u
+}
+
// SetMonitorID sets the "monitor_id" field.
func (_u *ChannelMonitorHistoryUpdateOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpdateOne {
_u.mutation.SetMonitorID(v)
@@ -557,6 +603,12 @@ func (_u *ChannelMonitorHistoryUpdateOne) sqlSave(ctx context.Context) (_node *C
}
}
}
+ if value, ok := _u.mutation.DeletedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldDeletedAt, field.TypeTime, value)
+ }
+ if _u.mutation.DeletedAtCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldDeletedAt, field.TypeTime)
+ }
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index 72ef2a36..ca208094 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -23,6 +23,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
@@ -72,6 +73,8 @@ type Client struct {
AuthIdentityChannel *AuthIdentityChannelClient
// ChannelMonitor is the client for interacting with the ChannelMonitor builders.
ChannelMonitor *ChannelMonitorClient
+ // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
+ ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
// ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
ChannelMonitorHistory *ChannelMonitorHistoryClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
@@ -139,6 +142,7 @@ func (c *Client) init() {
c.AuthIdentity = NewAuthIdentityClient(c.config)
c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
c.ChannelMonitor = NewChannelMonitorClient(c.config)
+ c.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(c.config)
c.ChannelMonitorHistory = NewChannelMonitorHistoryClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
@@ -253,40 +257,41 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- AuthIdentity: NewAuthIdentityClient(cfg),
- AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
- ChannelMonitor: NewChannelMonitorClient(cfg),
- ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PendingAuthSession: NewPendingAuthSessionClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -304,40 +309,41 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- AuthIdentity: NewAuthIdentityClient(cfg),
- AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
- ChannelMonitor: NewChannelMonitorClient(cfg),
- ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PendingAuthSession: NewPendingAuthSessionClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -369,12 +375,12 @@ func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
- c.ChannelMonitorHistory, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord,
- c.IdentityAdoptionDecision, c.PaymentAuditLog, c.PaymentOrder,
- c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory, c.ErrorPassthroughRule,
+ c.Group, c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -387,12 +393,12 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
- c.ChannelMonitorHistory, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord,
- c.IdentityAdoptionDecision, c.PaymentAuditLog, c.PaymentOrder,
- c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory, c.ErrorPassthroughRule,
+ c.Group, c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -418,6 +424,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.AuthIdentityChannel.mutate(ctx, m)
case *ChannelMonitorMutation:
return c.ChannelMonitor.mutate(ctx, m)
+ case *ChannelMonitorDailyRollupMutation:
+ return c.ChannelMonitorDailyRollup.mutate(ctx, m)
case *ChannelMonitorHistoryMutation:
return c.ChannelMonitorHistory.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
@@ -1737,6 +1745,22 @@ func (c *ChannelMonitorClient) QueryHistory(_m *ChannelMonitor) *ChannelMonitorH
return query
}
+// QueryDailyRollups queries the daily_rollups edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryDailyRollups(_m *ChannelMonitor) *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// Hooks returns the client hooks.
func (c *ChannelMonitorClient) Hooks() []Hook {
return c.hooks.ChannelMonitor
@@ -1762,6 +1786,157 @@ func (c *ChannelMonitorClient) mutate(ctx context.Context, m *ChannelMonitorMuta
}
}
+// ChannelMonitorDailyRollupClient is a client for the ChannelMonitorDailyRollup schema.
+type ChannelMonitorDailyRollupClient struct {
+ config
+}
+
+// NewChannelMonitorDailyRollupClient returns a client for the ChannelMonitorDailyRollup from the given config.
+func NewChannelMonitorDailyRollupClient(c config) *ChannelMonitorDailyRollupClient {
+ return &ChannelMonitorDailyRollupClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitordailyrollup.Hooks(f(g(h())))`.
+func (c *ChannelMonitorDailyRollupClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorDailyRollup = append(c.hooks.ChannelMonitorDailyRollup, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitordailyrollup.Intercept(f(g(h())))`.
+func (c *ChannelMonitorDailyRollupClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorDailyRollup = append(c.inters.ChannelMonitorDailyRollup, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorDailyRollup entity.
+func (c *ChannelMonitorDailyRollupClient) Create() *ChannelMonitorDailyRollupCreate {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpCreate)
+ return &ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorDailyRollup entities.
+func (c *ChannelMonitorDailyRollupClient) CreateBulk(builders ...*ChannelMonitorDailyRollupCreate) *ChannelMonitorDailyRollupCreateBulk {
+ return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorDailyRollupClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorDailyRollupCreate, int)) *ChannelMonitorDailyRollupCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorDailyRollupCreateBulk{err: fmt.Errorf("calling to ChannelMonitorDailyRollupClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorDailyRollupCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Update() *ChannelMonitorDailyRollupUpdate {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdate)
+ return &ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorDailyRollupClient) UpdateOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollup(_m))
+ return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorDailyRollupClient) UpdateOneID(id int64) *ChannelMonitorDailyRollupUpdateOne {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollupID(id))
+ return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Delete() *ChannelMonitorDailyRollupDelete {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpDelete)
+ return &ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorDailyRollupClient) DeleteOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorDailyRollupClient) DeleteOneID(id int64) *ChannelMonitorDailyRollupDeleteOne {
+ builder := c.Delete().Where(channelmonitordailyrollup.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorDailyRollupDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Query() *ChannelMonitorDailyRollupQuery {
+ return &ChannelMonitorDailyRollupQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorDailyRollup},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorDailyRollup entity by its id.
+func (c *ChannelMonitorDailyRollupClient) Get(ctx context.Context, id int64) (*ChannelMonitorDailyRollup, error) {
+ return c.Query().Where(channelmonitordailyrollup.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorDailyRollupClient) GetX(ctx context.Context, id int64) *ChannelMonitorDailyRollup {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitor queries the monitor edge of a ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) QueryMonitor(_m *ChannelMonitorDailyRollup) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorDailyRollupClient) Hooks() []Hook {
+ hooks := c.hooks.ChannelMonitorDailyRollup
+ return append(hooks[:len(hooks):len(hooks)], channelmonitordailyrollup.Hooks[:]...)
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorDailyRollupClient) Interceptors() []Interceptor {
+ inters := c.inters.ChannelMonitorDailyRollup
+ return append(inters[:len(inters):len(inters)], channelmonitordailyrollup.Interceptors[:]...)
+}
+
+func (c *ChannelMonitorDailyRollupClient) mutate(ctx context.Context, m *ChannelMonitorDailyRollupMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorDailyRollup mutation op: %q", m.Op())
+ }
+}
+
// ChannelMonitorHistoryClient is a client for the ChannelMonitorHistory schema.
type ChannelMonitorHistoryClient struct {
config
@@ -1888,12 +2063,14 @@ func (c *ChannelMonitorHistoryClient) QueryMonitor(_m *ChannelMonitorHistory) *C
// Hooks returns the client hooks.
func (c *ChannelMonitorHistoryClient) Hooks() []Hook {
- return c.hooks.ChannelMonitorHistory
+ hooks := c.hooks.ChannelMonitorHistory
+ return append(hooks[:len(hooks):len(hooks)], channelmonitorhistory.Hooks[:]...)
}
// Interceptors returns the client interceptors.
func (c *ChannelMonitorHistoryClient) Interceptors() []Interceptor {
- return c.inters.ChannelMonitorHistory
+ inters := c.inters.ChannelMonitorHistory
+ return append(inters[:len(inters):len(inters)], channelmonitorhistory.Interceptors[:]...)
}
func (c *ChannelMonitorHistoryClient) mutate(ctx context.Context, m *ChannelMonitorHistoryMutation) (Value, error) {
@@ -5671,23 +5848,23 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
- AuthIdentityChannel, ChannelMonitor, ChannelMonitorHistory,
- ErrorPassthroughRule, Group, IdempotencyRecord, IdentityAdoptionDecision,
- PaymentAuditLog, PaymentOrder, PaymentProviderInstance, PendingAuthSession,
- PromoCode, PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting,
- SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User,
- UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
- UserSubscription []ent.Hook
+ AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
+ ChannelMonitorHistory, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
+ UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
- AuthIdentityChannel, ChannelMonitor, ChannelMonitorHistory,
- ErrorPassthroughRule, Group, IdempotencyRecord, IdentityAdoptionDecision,
- PaymentAuditLog, PaymentOrder, PaymentProviderInstance, PendingAuthSession,
- PromoCode, PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting,
- SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User,
- UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
- UserSubscription []ent.Interceptor
+ AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
+ ChannelMonitorHistory, ErrorPassthroughRule, Group, IdempotencyRecord,
+ IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
+ PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
+ RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
+ UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
+ UserAttributeValue, UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index e03ea74e..71d17624 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
@@ -104,38 +105,39 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- authidentity.Table: authidentity.ValidColumn,
- authidentitychannel.Table: authidentitychannel.ValidColumn,
- channelmonitor.Table: channelmonitor.ValidColumn,
- channelmonitorhistory.Table: channelmonitorhistory.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
- paymentauditlog.Table: paymentauditlog.ValidColumn,
- paymentorder.Table: paymentorder.ValidColumn,
- paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
- pendingauthsession.Table: pendingauthsession.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- subscriptionplan.Table: subscriptionplan.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ channelmonitor.Table: channelmonitor.ValidColumn,
+ channelmonitordailyrollup.Table: channelmonitordailyrollup.ValidColumn,
+ channelmonitorhistory.Table: channelmonitorhistory.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index e2ffec31..ff86c90d 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -105,6 +105,18 @@ func (f ChannelMonitorFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Val
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorMutation", m)
}
+// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorDailyRollup mutator.
+type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorDailyRollupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorDailyRollupMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorDailyRollupMutation", m)
+}
+
// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary
// function as ChannelMonitorHistory mutator.
type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryMutation) (ent.Value, error)
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 1f11755b..0c83fc38 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
@@ -315,6 +316,33 @@ func (f TraverseChannelMonitor) Traverse(ctx context.Context, q ent.Query) error
return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q)
}
+// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorDailyRollupFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q)
+}
+
+// The TraverseChannelMonitorDailyRollup type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorDailyRollup func(context.Context, *ent.ChannelMonitorDailyRollupQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorDailyRollup) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorDailyRollup) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q)
+}
+
// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary function as a Querier.
type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryQuery) (ent.Value, error)
@@ -982,6 +1010,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
case *ent.ChannelMonitorQuery:
return &query[*ent.ChannelMonitorQuery, predicate.ChannelMonitor, channelmonitor.OrderOption]{typ: ent.TypeChannelMonitor, tq: q}, nil
+ case *ent.ChannelMonitorDailyRollupQuery:
+ return &query[*ent.ChannelMonitorDailyRollupQuery, predicate.ChannelMonitorDailyRollup, channelmonitordailyrollup.OrderOption]{typ: ent.TypeChannelMonitorDailyRollup, tq: q}, nil
case *ent.ChannelMonitorHistoryQuery:
return &query[*ent.ChannelMonitorHistoryQuery, predicate.ChannelMonitorHistory, channelmonitorhistory.OrderOption]{typ: ent.TypeChannelMonitorHistory, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 3dc17fa2..9ce914a3 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -461,9 +461,55 @@ var (
},
},
}
+ // ChannelMonitorDailyRollupsColumns holds the columns for the "channel_monitor_daily_rollups" table.
+ ChannelMonitorDailyRollupsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "model", Type: field.TypeString, Size: 200},
+ {Name: "bucket_date", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "date"}},
+ {Name: "total_checks", Type: field.TypeInt, Default: 0},
+ {Name: "ok_count", Type: field.TypeInt, Default: 0},
+ {Name: "operational_count", Type: field.TypeInt, Default: 0},
+ {Name: "degraded_count", Type: field.TypeInt, Default: 0},
+ {Name: "failed_count", Type: field.TypeInt, Default: 0},
+ {Name: "error_count", Type: field.TypeInt, Default: 0},
+ {Name: "sum_latency_ms", Type: field.TypeInt64, Default: 0},
+ {Name: "count_latency", Type: field.TypeInt, Default: 0},
+ {Name: "sum_ping_latency_ms", Type: field.TypeInt64, Default: 0},
+ {Name: "count_ping_latency", Type: field.TypeInt, Default: 0},
+ {Name: "computed_at", Type: field.TypeTime},
+ {Name: "monitor_id", Type: field.TypeInt64},
+ }
+ // ChannelMonitorDailyRollupsTable holds the schema information for the "channel_monitor_daily_rollups" table.
+ ChannelMonitorDailyRollupsTable = &schema.Table{
+ Name: "channel_monitor_daily_rollups",
+ Columns: ChannelMonitorDailyRollupsColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorDailyRollupsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitor_daily_rollups_channel_monitors_daily_rollups",
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[15]},
+ RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitordailyrollup_monitor_id_model_bucket_date",
+ Unique: true,
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[15], ChannelMonitorDailyRollupsColumns[2], ChannelMonitorDailyRollupsColumns[3]},
+ },
+ {
+ Name: "channelmonitordailyrollup_bucket_date",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[3]},
+ },
+ },
+ }
// ChannelMonitorHistoriesColumns holds the columns for the "channel_monitor_histories" table.
ChannelMonitorHistoriesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "model", Type: field.TypeString, Size: 200},
{Name: "status", Type: field.TypeEnum, Enums: []string{"operational", "degraded", "failed", "error"}},
{Name: "latency_ms", Type: field.TypeInt, Nullable: true},
@@ -480,7 +526,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "channel_monitor_histories_channel_monitors_history",
- Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]},
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[8]},
RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
OnDelete: schema.Cascade,
},
@@ -489,12 +535,12 @@ var (
{
Name: "channelmonitorhistory_monitor_id_model_checked_at",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7], ChannelMonitorHistoriesColumns[1], ChannelMonitorHistoriesColumns[6]},
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[8], ChannelMonitorHistoriesColumns[2], ChannelMonitorHistoriesColumns[7]},
},
{
Name: "channelmonitorhistory_checked_at",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorHistoriesColumns[6]},
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]},
},
},
}
@@ -1598,6 +1644,7 @@ var (
AuthIdentitiesTable,
AuthIdentityChannelsTable,
ChannelMonitorsTable,
+ ChannelMonitorDailyRollupsTable,
ChannelMonitorHistoriesTable,
ErrorPassthroughRulesTable,
GroupsTable,
@@ -1659,6 +1706,10 @@ func init() {
ChannelMonitorsTable.Annotation = &entsql.Annotation{
Table: "channel_monitors",
}
+ ChannelMonitorDailyRollupsTable.ForeignKeys[0].RefTable = ChannelMonitorsTable
+ ChannelMonitorDailyRollupsTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_daily_rollups",
+ }
ChannelMonitorHistoriesTable.ForeignKeys[0].RefTable = ChannelMonitorsTable
ChannelMonitorHistoriesTable.Annotation = &entsql.Annotation{
Table: "channel_monitor_histories",
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 528ace5f..e97456fe 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
@@ -57,38 +58,39 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeAuthIdentity = "AuthIdentity"
- TypeAuthIdentityChannel = "AuthIdentityChannel"
- TypeChannelMonitor = "ChannelMonitor"
- TypeChannelMonitorHistory = "ChannelMonitorHistory"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
- TypePaymentAuditLog = "PaymentAuditLog"
- TypePaymentOrder = "PaymentOrder"
- TypePaymentProviderInstance = "PaymentProviderInstance"
- TypePendingAuthSession = "PendingAuthSession"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeSubscriptionPlan = "SubscriptionPlan"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeChannelMonitor = "ChannelMonitor"
+ TypeChannelMonitorDailyRollup = "ChannelMonitorDailyRollup"
+ TypeChannelMonitorHistory = "ChannelMonitorHistory"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -8741,32 +8743,35 @@ func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
// ChannelMonitorMutation represents an operation that mutates the ChannelMonitor nodes in the graph.
type ChannelMonitorMutation struct {
config
- op Op
- typ string
- id *int64
- created_at *time.Time
- updated_at *time.Time
- name *string
- provider *channelmonitor.Provider
- endpoint *string
- api_key_encrypted *string
- primary_model *string
- extra_models *[]string
- appendextra_models []string
- group_name *string
- enabled *bool
- interval_seconds *int
- addinterval_seconds *int
- last_checked_at *time.Time
- created_by *int64
- addcreated_by *int64
- clearedFields map[string]struct{}
- history map[int64]struct{}
- removedhistory map[int64]struct{}
- clearedhistory bool
- done bool
- oldValue func(context.Context) (*ChannelMonitor, error)
- predicates []predicate.ChannelMonitor
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ provider *channelmonitor.Provider
+ endpoint *string
+ api_key_encrypted *string
+ primary_model *string
+ extra_models *[]string
+ appendextra_models []string
+ group_name *string
+ enabled *bool
+ interval_seconds *int
+ addinterval_seconds *int
+ last_checked_at *time.Time
+ created_by *int64
+ addcreated_by *int64
+ clearedFields map[string]struct{}
+ history map[int64]struct{}
+ removedhistory map[int64]struct{}
+ clearedhistory bool
+ daily_rollups map[int64]struct{}
+ removeddaily_rollups map[int64]struct{}
+ cleareddaily_rollups bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitor, error)
+ predicates []predicate.ChannelMonitor
}
var _ ent.Mutation = (*ChannelMonitorMutation)(nil)
@@ -9470,6 +9475,60 @@ func (m *ChannelMonitorMutation) ResetHistory() {
m.removedhistory = nil
}
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by ids.
+func (m *ChannelMonitorMutation) AddDailyRollupIDs(ids ...int64) {
+ if m.daily_rollups == nil {
+ m.daily_rollups = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.daily_rollups[ids[i]] = struct{}{}
+ }
+}
+
+// ClearDailyRollups clears the "daily_rollups" edge to the ChannelMonitorDailyRollup entity.
+func (m *ChannelMonitorMutation) ClearDailyRollups() {
+ m.cleareddaily_rollups = true
+}
+
+// DailyRollupsCleared reports if the "daily_rollups" edge to the ChannelMonitorDailyRollup entity was cleared.
+func (m *ChannelMonitorMutation) DailyRollupsCleared() bool {
+ return m.cleareddaily_rollups
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (m *ChannelMonitorMutation) RemoveDailyRollupIDs(ids ...int64) {
+ if m.removeddaily_rollups == nil {
+ m.removeddaily_rollups = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.daily_rollups, ids[i])
+ m.removeddaily_rollups[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedDailyRollups returns the removed IDs of the "daily_rollups" edge to the ChannelMonitorDailyRollup entity.
+func (m *ChannelMonitorMutation) RemovedDailyRollupsIDs() (ids []int64) {
+ for id := range m.removeddaily_rollups {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// DailyRollupsIDs returns the "daily_rollups" edge IDs in the mutation.
+func (m *ChannelMonitorMutation) DailyRollupsIDs() (ids []int64) {
+ for id := range m.daily_rollups {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetDailyRollups resets all changes to the "daily_rollups" edge.
+func (m *ChannelMonitorMutation) ResetDailyRollups() {
+ m.daily_rollups = nil
+ m.cleareddaily_rollups = false
+ m.removeddaily_rollups = nil
+}
+
// Where appends a list predicates to the ChannelMonitorMutation builder.
func (m *ChannelMonitorMutation) Where(ps ...predicate.ChannelMonitor) {
m.predicates = append(m.predicates, ps...)
@@ -9849,10 +9908,13 @@ func (m *ChannelMonitorMutation) ResetField(name string) error {
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *ChannelMonitorMutation) AddedEdges() []string {
- edges := make([]string, 0, 1)
+ edges := make([]string, 0, 2)
if m.history != nil {
edges = append(edges, channelmonitor.EdgeHistory)
}
+ if m.daily_rollups != nil {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
return edges
}
@@ -9866,16 +9928,25 @@ func (m *ChannelMonitorMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case channelmonitor.EdgeDailyRollups:
+ ids := make([]ent.Value, 0, len(m.daily_rollups))
+ for id := range m.daily_rollups {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *ChannelMonitorMutation) RemovedEdges() []string {
- edges := make([]string, 0, 1)
+ edges := make([]string, 0, 2)
if m.removedhistory != nil {
edges = append(edges, channelmonitor.EdgeHistory)
}
+ if m.removeddaily_rollups != nil {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
return edges
}
@@ -9889,16 +9960,25 @@ func (m *ChannelMonitorMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case channelmonitor.EdgeDailyRollups:
+ ids := make([]ent.Value, 0, len(m.removeddaily_rollups))
+ for id := range m.removeddaily_rollups {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *ChannelMonitorMutation) ClearedEdges() []string {
- edges := make([]string, 0, 1)
+ edges := make([]string, 0, 2)
if m.clearedhistory {
edges = append(edges, channelmonitor.EdgeHistory)
}
+ if m.cleareddaily_rollups {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
return edges
}
@@ -9908,6 +9988,8 @@ func (m *ChannelMonitorMutation) EdgeCleared(name string) bool {
switch name {
case channelmonitor.EdgeHistory:
return m.clearedhistory
+ case channelmonitor.EdgeDailyRollups:
+ return m.cleareddaily_rollups
}
return false
}
@@ -9927,43 +10009,62 @@ func (m *ChannelMonitorMutation) ResetEdge(name string) error {
case channelmonitor.EdgeHistory:
m.ResetHistory()
return nil
+ case channelmonitor.EdgeDailyRollups:
+ m.ResetDailyRollups()
+ return nil
}
return fmt.Errorf("unknown ChannelMonitor edge %s", name)
}
-// ChannelMonitorHistoryMutation represents an operation that mutates the ChannelMonitorHistory nodes in the graph.
-type ChannelMonitorHistoryMutation struct {
+// ChannelMonitorDailyRollupMutation represents an operation that mutates the ChannelMonitorDailyRollup nodes in the graph.
+type ChannelMonitorDailyRollupMutation struct {
config
- op Op
- typ string
- id *int64
- model *string
- status *channelmonitorhistory.Status
- latency_ms *int
- addlatency_ms *int
- ping_latency_ms *int
- addping_latency_ms *int
- message *string
- checked_at *time.Time
- clearedFields map[string]struct{}
- monitor *int64
- clearedmonitor bool
- done bool
- oldValue func(context.Context) (*ChannelMonitorHistory, error)
- predicates []predicate.ChannelMonitorHistory
-}
-
-var _ ent.Mutation = (*ChannelMonitorHistoryMutation)(nil)
-
-// channelmonitorhistoryOption allows management of the mutation configuration using functional options.
-type channelmonitorhistoryOption func(*ChannelMonitorHistoryMutation)
-
-// newChannelMonitorHistoryMutation creates new mutation for the ChannelMonitorHistory entity.
-func newChannelMonitorHistoryMutation(c config, op Op, opts ...channelmonitorhistoryOption) *ChannelMonitorHistoryMutation {
- m := &ChannelMonitorHistoryMutation{
+ op Op
+ typ string
+ id *int64
+ deleted_at *time.Time
+ model *string
+ bucket_date *time.Time
+ total_checks *int
+ addtotal_checks *int
+ ok_count *int
+ addok_count *int
+ operational_count *int
+ addoperational_count *int
+ degraded_count *int
+ adddegraded_count *int
+ failed_count *int
+ addfailed_count *int
+ error_count *int
+ adderror_count *int
+ sum_latency_ms *int64
+ addsum_latency_ms *int64
+ count_latency *int
+ addcount_latency *int
+ sum_ping_latency_ms *int64
+ addsum_ping_latency_ms *int64
+ count_ping_latency *int
+ addcount_ping_latency *int
+ computed_at *time.Time
+ clearedFields map[string]struct{}
+ monitor *int64
+ clearedmonitor bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorDailyRollup, error)
+ predicates []predicate.ChannelMonitorDailyRollup
+}
+
+var _ ent.Mutation = (*ChannelMonitorDailyRollupMutation)(nil)
+
+// channelmonitordailyrollupOption allows management of the mutation configuration using functional options.
+type channelmonitordailyrollupOption func(*ChannelMonitorDailyRollupMutation)
+
+// newChannelMonitorDailyRollupMutation creates new mutation for the ChannelMonitorDailyRollup entity.
+func newChannelMonitorDailyRollupMutation(c config, op Op, opts ...channelmonitordailyrollupOption) *ChannelMonitorDailyRollupMutation {
+ m := &ChannelMonitorDailyRollupMutation{
config: c,
op: op,
- typ: TypeChannelMonitorHistory,
+ typ: TypeChannelMonitorDailyRollup,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
@@ -9972,20 +10073,20 @@ func newChannelMonitorHistoryMutation(c config, op Op, opts ...channelmonitorhis
return m
}
-// withChannelMonitorHistoryID sets the ID field of the mutation.
-func withChannelMonitorHistoryID(id int64) channelmonitorhistoryOption {
- return func(m *ChannelMonitorHistoryMutation) {
+// withChannelMonitorDailyRollupID sets the ID field of the mutation.
+func withChannelMonitorDailyRollupID(id int64) channelmonitordailyrollupOption {
+ return func(m *ChannelMonitorDailyRollupMutation) {
var (
err error
once sync.Once
- value *ChannelMonitorHistory
+ value *ChannelMonitorDailyRollup
)
- m.oldValue = func(ctx context.Context) (*ChannelMonitorHistory, error) {
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
- value, err = m.Client().ChannelMonitorHistory.Get(ctx, id)
+ value, err = m.Client().ChannelMonitorDailyRollup.Get(ctx, id)
}
})
return value, err
@@ -9994,10 +10095,10 @@ func withChannelMonitorHistoryID(id int64) channelmonitorhistoryOption {
}
}
-// withChannelMonitorHistory sets the old ChannelMonitorHistory of the mutation.
-func withChannelMonitorHistory(node *ChannelMonitorHistory) channelmonitorhistoryOption {
- return func(m *ChannelMonitorHistoryMutation) {
- m.oldValue = func(context.Context) (*ChannelMonitorHistory, error) {
+// withChannelMonitorDailyRollup sets the old ChannelMonitorDailyRollup of the mutation.
+func withChannelMonitorDailyRollup(node *ChannelMonitorDailyRollup) channelmonitordailyrollupOption {
+ return func(m *ChannelMonitorDailyRollupMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorDailyRollup, error) {
return node, nil
}
m.id = &node.ID
@@ -10006,7 +10107,7 @@ func withChannelMonitorHistory(node *ChannelMonitorHistory) channelmonitorhistor
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
-func (m ChannelMonitorHistoryMutation) Client() *Client {
+func (m ChannelMonitorDailyRollupMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
@@ -10014,7 +10115,7 @@ func (m ChannelMonitorHistoryMutation) Client() *Client {
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
-func (m ChannelMonitorHistoryMutation) Tx() (*Tx, error) {
+func (m ChannelMonitorDailyRollupMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("ent: mutation is not running in a transaction")
}
@@ -10025,7 +10126,7 @@ func (m ChannelMonitorHistoryMutation) Tx() (*Tx, error) {
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
-func (m *ChannelMonitorHistoryMutation) ID() (id int64, exists bool) {
+func (m *ChannelMonitorDailyRollupMutation) ID() (id int64, exists bool) {
if m.id == nil {
return
}
@@ -10036,7 +10137,7 @@ func (m *ChannelMonitorHistoryMutation) ID() (id int64, exists bool) {
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
-func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error) {
+func (m *ChannelMonitorDailyRollupMutation) IDs(ctx context.Context) ([]int64, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
@@ -10045,19 +10146,68 @@ func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
- return m.Client().ChannelMonitorHistory.Query().Where(m.predicates...).IDs(ctx)
+ return m.Client().ChannelMonitorDailyRollup.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
+// SetDeletedAt sets the "deleted_at" field.
+func (m *ChannelMonitorDailyRollupMutation) SetDeletedAt(t time.Time) {
+ m.deleted_at = &t
+}
+
+// DeletedAt returns the value of the "deleted_at" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) DeletedAt() (r time.Time, exists bool) {
+ v := m.deleted_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDeletedAt returns the old "deleted_at" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDeletedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err)
+ }
+ return oldValue.DeletedAt, nil
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (m *ChannelMonitorDailyRollupMutation) ClearDeletedAt() {
+ m.deleted_at = nil
+ m.clearedFields[channelmonitordailyrollup.FieldDeletedAt] = struct{}{}
+}
+
+// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) DeletedAtCleared() bool {
+ _, ok := m.clearedFields[channelmonitordailyrollup.FieldDeletedAt]
+ return ok
+}
+
+// ResetDeletedAt resets all changes to the "deleted_at" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetDeletedAt() {
+ m.deleted_at = nil
+ delete(m.clearedFields, channelmonitordailyrollup.FieldDeletedAt)
+}
+
// SetMonitorID sets the "monitor_id" field.
-func (m *ChannelMonitorHistoryMutation) SetMonitorID(i int64) {
+func (m *ChannelMonitorDailyRollupMutation) SetMonitorID(i int64) {
m.monitor = &i
}
// MonitorID returns the value of the "monitor_id" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) MonitorID() (r int64, exists bool) {
+func (m *ChannelMonitorDailyRollupMutation) MonitorID() (r int64, exists bool) {
v := m.monitor
if v == nil {
return
@@ -10065,10 +10215,10 @@ func (m *ChannelMonitorHistoryMutation) MonitorID() (r int64, exists bool) {
return *v, true
}
-// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
+func (m *ChannelMonitorDailyRollupMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMonitorID is only allowed on UpdateOne operations")
}
@@ -10083,17 +10233,17 @@ func (m *ChannelMonitorHistoryMutation) OldMonitorID(ctx context.Context) (v int
}
// ResetMonitorID resets all changes to the "monitor_id" field.
-func (m *ChannelMonitorHistoryMutation) ResetMonitorID() {
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitorID() {
m.monitor = nil
}
// SetModel sets the "model" field.
-func (m *ChannelMonitorHistoryMutation) SetModel(s string) {
+func (m *ChannelMonitorDailyRollupMutation) SetModel(s string) {
m.model = &s
}
// Model returns the value of the "model" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) Model() (r string, exists bool) {
+func (m *ChannelMonitorDailyRollupMutation) Model() (r string, exists bool) {
v := m.model
if v == nil {
return
@@ -10101,10 +10251,10 @@ func (m *ChannelMonitorHistoryMutation) Model() (r string, exists bool) {
return *v, true
}
-// OldModel returns the old "model" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// OldModel returns the old "model" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldModel(ctx context.Context) (v string, err error) {
+func (m *ChannelMonitorDailyRollupMutation) OldModel(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldModel is only allowed on UpdateOne operations")
}
@@ -10119,206 +10269,1682 @@ func (m *ChannelMonitorHistoryMutation) OldModel(ctx context.Context) (v string,
}
// ResetModel resets all changes to the "model" field.
-func (m *ChannelMonitorHistoryMutation) ResetModel() {
+func (m *ChannelMonitorDailyRollupMutation) ResetModel() {
m.model = nil
}
-// SetStatus sets the "status" field.
-func (m *ChannelMonitorHistoryMutation) SetStatus(c channelmonitorhistory.Status) {
- m.status = &c
+// SetBucketDate sets the "bucket_date" field.
+func (m *ChannelMonitorDailyRollupMutation) SetBucketDate(t time.Time) {
+ m.bucket_date = &t
}
-// Status returns the value of the "status" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) Status() (r channelmonitorhistory.Status, exists bool) {
- v := m.status
+// BucketDate returns the value of the "bucket_date" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) BucketDate() (r time.Time, exists bool) {
+ v := m.bucket_date
if v == nil {
return
}
return *v, true
}
-// OldStatus returns the old "status" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// OldBucketDate returns the old "bucket_date" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldStatus(ctx context.Context) (v channelmonitorhistory.Status, err error) {
+func (m *ChannelMonitorDailyRollupMutation) OldBucketDate(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ return v, errors.New("OldBucketDate is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldStatus requires an ID field in the mutation")
+ return v, errors.New("OldBucketDate requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ return v, fmt.Errorf("querying old value for OldBucketDate: %w", err)
}
- return oldValue.Status, nil
+ return oldValue.BucketDate, nil
}
-// ResetStatus resets all changes to the "status" field.
-func (m *ChannelMonitorHistoryMutation) ResetStatus() {
- m.status = nil
+// ResetBucketDate resets all changes to the "bucket_date" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetBucketDate() {
+ m.bucket_date = nil
}
-// SetLatencyMs sets the "latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) SetLatencyMs(i int) {
- m.latency_ms = &i
- m.addlatency_ms = nil
+// SetTotalChecks sets the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) SetTotalChecks(i int) {
+ m.total_checks = &i
+ m.addtotal_checks = nil
}
-// LatencyMs returns the value of the "latency_ms" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) LatencyMs() (r int, exists bool) {
- v := m.latency_ms
+// TotalChecks returns the value of the "total_checks" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) TotalChecks() (r int, exists bool) {
+ v := m.total_checks
if v == nil {
return
}
return *v, true
}
-// OldLatencyMs returns the old "latency_ms" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// OldTotalChecks returns the old "total_checks" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldLatencyMs(ctx context.Context) (v *int, err error) {
+func (m *ChannelMonitorDailyRollupMutation) OldTotalChecks(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldLatencyMs is only allowed on UpdateOne operations")
+ return v, errors.New("OldTotalChecks is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldLatencyMs requires an ID field in the mutation")
+ return v, errors.New("OldTotalChecks requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldLatencyMs: %w", err)
+ return v, fmt.Errorf("querying old value for OldTotalChecks: %w", err)
}
- return oldValue.LatencyMs, nil
+ return oldValue.TotalChecks, nil
}
-// AddLatencyMs adds i to the "latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) AddLatencyMs(i int) {
- if m.addlatency_ms != nil {
- *m.addlatency_ms += i
+// AddTotalChecks adds i to the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) AddTotalChecks(i int) {
+ if m.addtotal_checks != nil {
+ *m.addtotal_checks += i
} else {
- m.addlatency_ms = &i
+ m.addtotal_checks = &i
}
}
-// AddedLatencyMs returns the value that was added to the "latency_ms" field in this mutation.
-func (m *ChannelMonitorHistoryMutation) AddedLatencyMs() (r int, exists bool) {
- v := m.addlatency_ms
+// AddedTotalChecks returns the value that was added to the "total_checks" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedTotalChecks() (r int, exists bool) {
+ v := m.addtotal_checks
if v == nil {
return
}
return *v, true
}
-// ClearLatencyMs clears the value of the "latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) ClearLatencyMs() {
- m.latency_ms = nil
- m.addlatency_ms = nil
- m.clearedFields[channelmonitorhistory.FieldLatencyMs] = struct{}{}
-}
-
-// LatencyMsCleared returns if the "latency_ms" field was cleared in this mutation.
-func (m *ChannelMonitorHistoryMutation) LatencyMsCleared() bool {
- _, ok := m.clearedFields[channelmonitorhistory.FieldLatencyMs]
- return ok
-}
-
-// ResetLatencyMs resets all changes to the "latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) ResetLatencyMs() {
- m.latency_ms = nil
- m.addlatency_ms = nil
- delete(m.clearedFields, channelmonitorhistory.FieldLatencyMs)
+// ResetTotalChecks resets all changes to the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetTotalChecks() {
+ m.total_checks = nil
+ m.addtotal_checks = nil
}
-// SetPingLatencyMs sets the "ping_latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) SetPingLatencyMs(i int) {
- m.ping_latency_ms = &i
- m.addping_latency_ms = nil
+// SetOkCount sets the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetOkCount(i int) {
+ m.ok_count = &i
+ m.addok_count = nil
}
-// PingLatencyMs returns the value of the "ping_latency_ms" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) PingLatencyMs() (r int, exists bool) {
- v := m.ping_latency_ms
+// OkCount returns the value of the "ok_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) OkCount() (r int, exists bool) {
+ v := m.ok_count
if v == nil {
return
}
return *v, true
}
-// OldPingLatencyMs returns the old "ping_latency_ms" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// OldOkCount returns the old "ok_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldPingLatencyMs(ctx context.Context) (v *int, err error) {
+func (m *ChannelMonitorDailyRollupMutation) OldOkCount(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldPingLatencyMs is only allowed on UpdateOne operations")
+ return v, errors.New("OldOkCount is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldPingLatencyMs requires an ID field in the mutation")
+ return v, errors.New("OldOkCount requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldPingLatencyMs: %w", err)
+ return v, fmt.Errorf("querying old value for OldOkCount: %w", err)
}
- return oldValue.PingLatencyMs, nil
+ return oldValue.OkCount, nil
}
-// AddPingLatencyMs adds i to the "ping_latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) AddPingLatencyMs(i int) {
- if m.addping_latency_ms != nil {
- *m.addping_latency_ms += i
+// AddOkCount adds i to the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddOkCount(i int) {
+ if m.addok_count != nil {
+ *m.addok_count += i
} else {
- m.addping_latency_ms = &i
+ m.addok_count = &i
}
}
-// AddedPingLatencyMs returns the value that was added to the "ping_latency_ms" field in this mutation.
-func (m *ChannelMonitorHistoryMutation) AddedPingLatencyMs() (r int, exists bool) {
- v := m.addping_latency_ms
+// AddedOkCount returns the value that was added to the "ok_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedOkCount() (r int, exists bool) {
+ v := m.addok_count
if v == nil {
return
}
return *v, true
}
-// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) ClearPingLatencyMs() {
- m.ping_latency_ms = nil
- m.addping_latency_ms = nil
- m.clearedFields[channelmonitorhistory.FieldPingLatencyMs] = struct{}{}
-}
-
-// PingLatencyMsCleared returns if the "ping_latency_ms" field was cleared in this mutation.
-func (m *ChannelMonitorHistoryMutation) PingLatencyMsCleared() bool {
- _, ok := m.clearedFields[channelmonitorhistory.FieldPingLatencyMs]
- return ok
-}
-
-// ResetPingLatencyMs resets all changes to the "ping_latency_ms" field.
-func (m *ChannelMonitorHistoryMutation) ResetPingLatencyMs() {
- m.ping_latency_ms = nil
- m.addping_latency_ms = nil
- delete(m.clearedFields, channelmonitorhistory.FieldPingLatencyMs)
+// ResetOkCount resets all changes to the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetOkCount() {
+ m.ok_count = nil
+ m.addok_count = nil
}
-// SetMessage sets the "message" field.
-func (m *ChannelMonitorHistoryMutation) SetMessage(s string) {
- m.message = &s
+// SetOperationalCount sets the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetOperationalCount(i int) {
+ m.operational_count = &i
+ m.addoperational_count = nil
}
-// Message returns the value of the "message" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) Message() (r string, exists bool) {
- v := m.message
+// OperationalCount returns the value of the "operational_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) OperationalCount() (r int, exists bool) {
+ v := m.operational_count
if v == nil {
return
}
return *v, true
}
-// OldMessage returns the old "message" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// OldOperationalCount returns the old "operational_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldMessage(ctx context.Context) (v string, err error) {
+func (m *ChannelMonitorDailyRollupMutation) OldOperationalCount(ctx context.Context) (v int, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMessage is only allowed on UpdateOne operations")
+ return v, errors.New("OldOperationalCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOperationalCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOperationalCount: %w", err)
+ }
+ return oldValue.OperationalCount, nil
+}
+
+// AddOperationalCount adds i to the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddOperationalCount(i int) {
+ if m.addoperational_count != nil {
+ *m.addoperational_count += i
+ } else {
+ m.addoperational_count = &i
+ }
+}
+
+// AddedOperationalCount returns the value that was added to the "operational_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedOperationalCount() (r int, exists bool) {
+ v := m.addoperational_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetOperationalCount resets all changes to the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetOperationalCount() {
+ m.operational_count = nil
+ m.addoperational_count = nil
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetDegradedCount(i int) {
+ m.degraded_count = &i
+ m.adddegraded_count = nil
+}
+
+// DegradedCount returns the value of the "degraded_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) DegradedCount() (r int, exists bool) {
+ v := m.degraded_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDegradedCount returns the old "degraded_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldDegradedCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDegradedCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDegradedCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDegradedCount: %w", err)
+ }
+ return oldValue.DegradedCount, nil
+}
+
+// AddDegradedCount adds i to the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddDegradedCount(i int) {
+ if m.adddegraded_count != nil {
+ *m.adddegraded_count += i
+ } else {
+ m.adddegraded_count = &i
+ }
+}
+
+// AddedDegradedCount returns the value that was added to the "degraded_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedDegradedCount() (r int, exists bool) {
+ v := m.adddegraded_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetDegradedCount resets all changes to the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetDegradedCount() {
+ m.degraded_count = nil
+ m.adddegraded_count = nil
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetFailedCount(i int) {
+ m.failed_count = &i
+ m.addfailed_count = nil
+}
+
+// FailedCount returns the value of the "failed_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) FailedCount() (r int, exists bool) {
+ v := m.failed_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFailedCount returns the old "failed_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldFailedCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFailedCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFailedCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFailedCount: %w", err)
+ }
+ return oldValue.FailedCount, nil
+}
+
+// AddFailedCount adds i to the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddFailedCount(i int) {
+ if m.addfailed_count != nil {
+ *m.addfailed_count += i
+ } else {
+ m.addfailed_count = &i
+ }
+}
+
+// AddedFailedCount returns the value that was added to the "failed_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedFailedCount() (r int, exists bool) {
+ v := m.addfailed_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetFailedCount resets all changes to the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetFailedCount() {
+ m.failed_count = nil
+ m.addfailed_count = nil
+}
+
+// SetErrorCount sets the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetErrorCount(i int) {
+ m.error_count = &i
+ m.adderror_count = nil
+}
+
+// ErrorCount returns the value of the "error_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) ErrorCount() (r int, exists bool) {
+ v := m.error_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldErrorCount returns the old "error_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldErrorCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldErrorCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldErrorCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldErrorCount: %w", err)
+ }
+ return oldValue.ErrorCount, nil
+}
+
+// AddErrorCount adds i to the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddErrorCount(i int) {
+ if m.adderror_count != nil {
+ *m.adderror_count += i
+ } else {
+ m.adderror_count = &i
+ }
+}
+
+// AddedErrorCount returns the value that was added to the "error_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedErrorCount() (r int, exists bool) {
+ v := m.adderror_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetErrorCount resets all changes to the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetErrorCount() {
+ m.error_count = nil
+ m.adderror_count = nil
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) SetSumLatencyMs(i int64) {
+ m.sum_latency_ms = &i
+ m.addsum_latency_ms = nil
+}
+
+// SumLatencyMs returns the value of the "sum_latency_ms" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) SumLatencyMs() (r int64, exists bool) {
+ v := m.sum_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSumLatencyMs returns the old "sum_latency_ms" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldSumLatencyMs(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSumLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSumLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSumLatencyMs: %w", err)
+ }
+ return oldValue.SumLatencyMs, nil
+}
+
+// AddSumLatencyMs adds i to the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) AddSumLatencyMs(i int64) {
+ if m.addsum_latency_ms != nil {
+ *m.addsum_latency_ms += i
+ } else {
+ m.addsum_latency_ms = &i
+ }
+}
+
+// AddedSumLatencyMs returns the value that was added to the "sum_latency_ms" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedSumLatencyMs() (r int64, exists bool) {
+ v := m.addsum_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSumLatencyMs resets all changes to the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetSumLatencyMs() {
+ m.sum_latency_ms = nil
+ m.addsum_latency_ms = nil
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) SetCountLatency(i int) {
+ m.count_latency = &i
+ m.addcount_latency = nil
+}
+
+// CountLatency returns the value of the "count_latency" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) CountLatency() (r int, exists bool) {
+ v := m.count_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCountLatency returns the old "count_latency" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldCountLatency(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCountLatency is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCountLatency requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCountLatency: %w", err)
+ }
+ return oldValue.CountLatency, nil
+}
+
+// AddCountLatency adds i to the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) AddCountLatency(i int) {
+ if m.addcount_latency != nil {
+ *m.addcount_latency += i
+ } else {
+ m.addcount_latency = &i
+ }
+}
+
+// AddedCountLatency returns the value that was added to the "count_latency" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedCountLatency() (r int, exists bool) {
+ v := m.addcount_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCountLatency resets all changes to the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetCountLatency() {
+ m.count_latency = nil
+ m.addcount_latency = nil
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) SetSumPingLatencyMs(i int64) {
+ m.sum_ping_latency_ms = &i
+ m.addsum_ping_latency_ms = nil
+}
+
+// SumPingLatencyMs returns the value of the "sum_ping_latency_ms" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) SumPingLatencyMs() (r int64, exists bool) {
+ v := m.sum_ping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSumPingLatencyMs returns the old "sum_ping_latency_ms" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldSumPingLatencyMs(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSumPingLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSumPingLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSumPingLatencyMs: %w", err)
+ }
+ return oldValue.SumPingLatencyMs, nil
+}
+
+// AddSumPingLatencyMs adds i to the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) AddSumPingLatencyMs(i int64) {
+ if m.addsum_ping_latency_ms != nil {
+ *m.addsum_ping_latency_ms += i
+ } else {
+ m.addsum_ping_latency_ms = &i
+ }
+}
+
+// AddedSumPingLatencyMs returns the value that was added to the "sum_ping_latency_ms" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedSumPingLatencyMs() (r int64, exists bool) {
+ v := m.addsum_ping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSumPingLatencyMs resets all changes to the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetSumPingLatencyMs() {
+ m.sum_ping_latency_ms = nil
+ m.addsum_ping_latency_ms = nil
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) SetCountPingLatency(i int) {
+ m.count_ping_latency = &i
+ m.addcount_ping_latency = nil
+}
+
+// CountPingLatency returns the value of the "count_ping_latency" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) CountPingLatency() (r int, exists bool) {
+ v := m.count_ping_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCountPingLatency returns the old "count_ping_latency" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldCountPingLatency(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCountPingLatency is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCountPingLatency requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCountPingLatency: %w", err)
+ }
+ return oldValue.CountPingLatency, nil
+}
+
+// AddCountPingLatency adds i to the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) AddCountPingLatency(i int) {
+ if m.addcount_ping_latency != nil {
+ *m.addcount_ping_latency += i
+ } else {
+ m.addcount_ping_latency = &i
+ }
+}
+
+// AddedCountPingLatency returns the value that was added to the "count_ping_latency" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedCountPingLatency() (r int, exists bool) {
+ v := m.addcount_ping_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCountPingLatency resets all changes to the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetCountPingLatency() {
+ m.count_ping_latency = nil
+ m.addcount_ping_latency = nil
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (m *ChannelMonitorDailyRollupMutation) SetComputedAt(t time.Time) {
+ m.computed_at = &t
+}
+
+// ComputedAt returns the value of the "computed_at" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) ComputedAt() (r time.Time, exists bool) {
+ v := m.computed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldComputedAt returns the old "computed_at" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldComputedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldComputedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldComputedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldComputedAt: %w", err)
+ }
+ return oldValue.ComputedAt, nil
+}
+
+// ResetComputedAt resets all changes to the "computed_at" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetComputedAt() {
+ m.computed_at = nil
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorDailyRollupMutation) ClearMonitor() {
+ m.clearedmonitor = true
+ m.clearedFields[channelmonitordailyrollup.FieldMonitorID] = struct{}{}
+}
+
+// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorDailyRollupMutation) MonitorCleared() bool {
+ return m.clearedmonitor
+}
+
+// MonitorIDs returns the "monitor" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// MonitorID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorDailyRollupMutation) MonitorIDs() (ids []int64) {
+ if id := m.monitor; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetMonitor resets all changes to the "monitor" edge.
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitor() {
+ m.monitor = nil
+ m.clearedmonitor = false
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupMutation builder.
+func (m *ChannelMonitorDailyRollupMutation) Where(ps ...predicate.ChannelMonitorDailyRollup) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorDailyRollupMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorDailyRollupMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitorDailyRollup, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorDailyRollupMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorDailyRollupMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorDailyRollup).
+func (m *ChannelMonitorDailyRollupMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorDailyRollupMutation) Fields() []string {
+ fields := make([]string, 0, 15)
+ if m.deleted_at != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldDeletedAt)
+ }
+ if m.monitor != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldMonitorID)
+ }
+ if m.model != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldModel)
+ }
+ if m.bucket_date != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldBucketDate)
+ }
+ if m.total_checks != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldTotalChecks)
+ }
+ if m.ok_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOkCount)
+ }
+ if m.operational_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOperationalCount)
+ }
+ if m.degraded_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldDegradedCount)
+ }
+ if m.failed_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldFailedCount)
+ }
+ if m.error_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldErrorCount)
+ }
+ if m.sum_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs)
+ }
+ if m.count_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountLatency)
+ }
+ if m.sum_ping_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs)
+ }
+ if m.count_ping_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency)
+ }
+ if m.computed_at != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldComputedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorDailyRollupMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitordailyrollup.FieldDeletedAt:
+ return m.DeletedAt()
+ case channelmonitordailyrollup.FieldMonitorID:
+ return m.MonitorID()
+ case channelmonitordailyrollup.FieldModel:
+ return m.Model()
+ case channelmonitordailyrollup.FieldBucketDate:
+ return m.BucketDate()
+ case channelmonitordailyrollup.FieldTotalChecks:
+ return m.TotalChecks()
+ case channelmonitordailyrollup.FieldOkCount:
+ return m.OkCount()
+ case channelmonitordailyrollup.FieldOperationalCount:
+ return m.OperationalCount()
+ case channelmonitordailyrollup.FieldDegradedCount:
+ return m.DegradedCount()
+ case channelmonitordailyrollup.FieldFailedCount:
+ return m.FailedCount()
+ case channelmonitordailyrollup.FieldErrorCount:
+ return m.ErrorCount()
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ return m.SumLatencyMs()
+ case channelmonitordailyrollup.FieldCountLatency:
+ return m.CountLatency()
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ return m.SumPingLatencyMs()
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ return m.CountPingLatency()
+ case channelmonitordailyrollup.FieldComputedAt:
+ return m.ComputedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorDailyRollupMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitordailyrollup.FieldDeletedAt:
+ return m.OldDeletedAt(ctx)
+ case channelmonitordailyrollup.FieldMonitorID:
+ return m.OldMonitorID(ctx)
+ case channelmonitordailyrollup.FieldModel:
+ return m.OldModel(ctx)
+ case channelmonitordailyrollup.FieldBucketDate:
+ return m.OldBucketDate(ctx)
+ case channelmonitordailyrollup.FieldTotalChecks:
+ return m.OldTotalChecks(ctx)
+ case channelmonitordailyrollup.FieldOkCount:
+ return m.OldOkCount(ctx)
+ case channelmonitordailyrollup.FieldOperationalCount:
+ return m.OldOperationalCount(ctx)
+ case channelmonitordailyrollup.FieldDegradedCount:
+ return m.OldDegradedCount(ctx)
+ case channelmonitordailyrollup.FieldFailedCount:
+ return m.OldFailedCount(ctx)
+ case channelmonitordailyrollup.FieldErrorCount:
+ return m.OldErrorCount(ctx)
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ return m.OldSumLatencyMs(ctx)
+ case channelmonitordailyrollup.FieldCountLatency:
+ return m.OldCountLatency(ctx)
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ return m.OldSumPingLatencyMs(ctx)
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ return m.OldCountPingLatency(ctx)
+ case channelmonitordailyrollup.FieldComputedAt:
+ return m.OldComputedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorDailyRollupMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitordailyrollup.FieldDeletedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDeletedAt(v)
+ return nil
+ case channelmonitordailyrollup.FieldMonitorID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMonitorID(v)
+ return nil
+ case channelmonitordailyrollup.FieldModel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModel(v)
+ return nil
+ case channelmonitordailyrollup.FieldBucketDate:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBucketDate(v)
+ return nil
+ case channelmonitordailyrollup.FieldTotalChecks:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotalChecks(v)
+ return nil
+ case channelmonitordailyrollup.FieldOkCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOkCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldOperationalCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOperationalCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldDegradedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDegradedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldFailedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFailedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldErrorCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetErrorCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSumLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCountLatency(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSumPingLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCountPingLatency(v)
+ return nil
+ case channelmonitordailyrollup.FieldComputedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetComputedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedFields() []string {
+ var fields []string
+ if m.addtotal_checks != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldTotalChecks)
+ }
+ if m.addok_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOkCount)
+ }
+ if m.addoperational_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOperationalCount)
+ }
+ if m.adddegraded_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldDegradedCount)
+ }
+ if m.addfailed_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldFailedCount)
+ }
+ if m.adderror_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldErrorCount)
+ }
+ if m.addsum_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs)
+ }
+ if m.addcount_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountLatency)
+ }
+ if m.addsum_ping_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs)
+ }
+ if m.addcount_ping_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitordailyrollup.FieldTotalChecks:
+ return m.AddedTotalChecks()
+ case channelmonitordailyrollup.FieldOkCount:
+ return m.AddedOkCount()
+ case channelmonitordailyrollup.FieldOperationalCount:
+ return m.AddedOperationalCount()
+ case channelmonitordailyrollup.FieldDegradedCount:
+ return m.AddedDegradedCount()
+ case channelmonitordailyrollup.FieldFailedCount:
+ return m.AddedFailedCount()
+ case channelmonitordailyrollup.FieldErrorCount:
+ return m.AddedErrorCount()
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ return m.AddedSumLatencyMs()
+ case channelmonitordailyrollup.FieldCountLatency:
+ return m.AddedCountLatency()
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ return m.AddedSumPingLatencyMs()
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ return m.AddedCountPingLatency()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorDailyRollupMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitordailyrollup.FieldTotalChecks:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddTotalChecks(v)
+ return nil
+ case channelmonitordailyrollup.FieldOkCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddOkCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldOperationalCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddOperationalCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldDegradedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddDegradedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldFailedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddFailedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldErrorCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddErrorCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSumLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCountLatency(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSumPingLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCountPingLatency(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorDailyRollupMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitordailyrollup.FieldDeletedAt) {
+ fields = append(fields, channelmonitordailyrollup.FieldDeletedAt)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitordailyrollup.FieldDeletedAt:
+ m.ClearDeletedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitordailyrollup.FieldDeletedAt:
+ m.ResetDeletedAt()
+ return nil
+ case channelmonitordailyrollup.FieldMonitorID:
+ m.ResetMonitorID()
+ return nil
+ case channelmonitordailyrollup.FieldModel:
+ m.ResetModel()
+ return nil
+ case channelmonitordailyrollup.FieldBucketDate:
+ m.ResetBucketDate()
+ return nil
+ case channelmonitordailyrollup.FieldTotalChecks:
+ m.ResetTotalChecks()
+ return nil
+ case channelmonitordailyrollup.FieldOkCount:
+ m.ResetOkCount()
+ return nil
+ case channelmonitordailyrollup.FieldOperationalCount:
+ m.ResetOperationalCount()
+ return nil
+ case channelmonitordailyrollup.FieldDegradedCount:
+ m.ResetDegradedCount()
+ return nil
+ case channelmonitordailyrollup.FieldFailedCount:
+ m.ResetFailedCount()
+ return nil
+ case channelmonitordailyrollup.FieldErrorCount:
+ m.ResetErrorCount()
+ return nil
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ m.ResetSumLatencyMs()
+ return nil
+ case channelmonitordailyrollup.FieldCountLatency:
+ m.ResetCountLatency()
+ return nil
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ m.ResetSumPingLatencyMs()
+ return nil
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ m.ResetCountPingLatency()
+ return nil
+ case channelmonitordailyrollup.FieldComputedAt:
+ m.ResetComputedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitor != nil {
+ edges = append(edges, channelmonitordailyrollup.EdgeMonitor)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ if id := m.monitor; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedmonitor {
+ edges = append(edges, channelmonitordailyrollup.EdgeMonitor)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ return m.clearedmonitor
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ClearEdge(name string) error {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ m.ClearMonitor()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ m.ResetMonitor()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup edge %s", name)
+}
+
+// ChannelMonitorHistoryMutation represents an operation that mutates the ChannelMonitorHistory nodes in the graph.
+type ChannelMonitorHistoryMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ deleted_at *time.Time
+ model *string
+ status *channelmonitorhistory.Status
+ latency_ms *int
+ addlatency_ms *int
+ ping_latency_ms *int
+ addping_latency_ms *int
+ message *string
+ checked_at *time.Time
+ clearedFields map[string]struct{}
+ monitor *int64
+ clearedmonitor bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorHistory, error)
+ predicates []predicate.ChannelMonitorHistory
+}
+
+var _ ent.Mutation = (*ChannelMonitorHistoryMutation)(nil)
+
+// channelmonitorhistoryOption allows management of the mutation configuration using functional options.
+type channelmonitorhistoryOption func(*ChannelMonitorHistoryMutation)
+
+// newChannelMonitorHistoryMutation creates new mutation for the ChannelMonitorHistory entity.
+func newChannelMonitorHistoryMutation(c config, op Op, opts ...channelmonitorhistoryOption) *ChannelMonitorHistoryMutation {
+ m := &ChannelMonitorHistoryMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitorHistory,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorHistoryID sets the ID field of the mutation.
+func withChannelMonitorHistoryID(id int64) channelmonitorhistoryOption {
+ return func(m *ChannelMonitorHistoryMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitorHistory
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorHistory, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitorHistory.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitorHistory sets the old ChannelMonitorHistory of the mutation.
+func withChannelMonitorHistory(node *ChannelMonitorHistory) channelmonitorhistoryOption {
+ return func(m *ChannelMonitorHistoryMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorHistory, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorHistoryMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorHistoryMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorHistoryMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorHistory.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetDeletedAt sets the "deleted_at" field.
+func (m *ChannelMonitorHistoryMutation) SetDeletedAt(t time.Time) {
+ m.deleted_at = &t
+}
+
+// DeletedAt returns the value of the "deleted_at" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) DeletedAt() (r time.Time, exists bool) {
+ v := m.deleted_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDeletedAt returns the old "deleted_at" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDeletedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err)
+ }
+ return oldValue.DeletedAt, nil
+}
+
+// ClearDeletedAt clears the value of the "deleted_at" field.
+func (m *ChannelMonitorHistoryMutation) ClearDeletedAt() {
+ m.deleted_at = nil
+ m.clearedFields[channelmonitorhistory.FieldDeletedAt] = struct{}{}
+}
+
+// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) DeletedAtCleared() bool {
+ _, ok := m.clearedFields[channelmonitorhistory.FieldDeletedAt]
+ return ok
+}
+
+// ResetDeletedAt resets all changes to the "deleted_at" field.
+func (m *ChannelMonitorHistoryMutation) ResetDeletedAt() {
+ m.deleted_at = nil
+ delete(m.clearedFields, channelmonitorhistory.FieldDeletedAt)
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (m *ChannelMonitorHistoryMutation) SetMonitorID(i int64) {
+ m.monitor = &i
+}
+
+// MonitorID returns the value of the "monitor_id" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) MonitorID() (r int64, exists bool) {
+ v := m.monitor
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMonitorID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMonitorID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMonitorID: %w", err)
+ }
+ return oldValue.MonitorID, nil
+}
+
+// ResetMonitorID resets all changes to the "monitor_id" field.
+func (m *ChannelMonitorHistoryMutation) ResetMonitorID() {
+ m.monitor = nil
+}
+
+// SetModel sets the "model" field.
+func (m *ChannelMonitorHistoryMutation) SetModel(s string) {
+ m.model = &s
+}
+
+// Model returns the value of the "model" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) Model() (r string, exists bool) {
+ v := m.model
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModel returns the old "model" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldModel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModel: %w", err)
+ }
+ return oldValue.Model, nil
+}
+
+// ResetModel resets all changes to the "model" field.
+func (m *ChannelMonitorHistoryMutation) ResetModel() {
+ m.model = nil
+}
+
+// SetStatus sets the "status" field.
+func (m *ChannelMonitorHistoryMutation) SetStatus(c channelmonitorhistory.Status) {
+ m.status = &c
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) Status() (r channelmonitorhistory.Status, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldStatus(ctx context.Context) (v channelmonitorhistory.Status, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *ChannelMonitorHistoryMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) SetLatencyMs(i int) {
+ m.latency_ms = &i
+ m.addlatency_ms = nil
+}
+
+// LatencyMs returns the value of the "latency_ms" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) LatencyMs() (r int, exists bool) {
+ v := m.latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLatencyMs returns the old "latency_ms" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldLatencyMs(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLatencyMs: %w", err)
+ }
+ return oldValue.LatencyMs, nil
+}
+
+// AddLatencyMs adds i to the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) AddLatencyMs(i int) {
+ if m.addlatency_ms != nil {
+ *m.addlatency_ms += i
+ } else {
+ m.addlatency_ms = &i
+ }
+}
+
+// AddedLatencyMs returns the value that was added to the "latency_ms" field in this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedLatencyMs() (r int, exists bool) {
+ v := m.addlatency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ClearLatencyMs() {
+ m.latency_ms = nil
+ m.addlatency_ms = nil
+ m.clearedFields[channelmonitorhistory.FieldLatencyMs] = struct{}{}
+}
+
+// LatencyMsCleared returns if the "latency_ms" field was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) LatencyMsCleared() bool {
+ _, ok := m.clearedFields[channelmonitorhistory.FieldLatencyMs]
+ return ok
+}
+
+// ResetLatencyMs resets all changes to the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ResetLatencyMs() {
+ m.latency_ms = nil
+ m.addlatency_ms = nil
+ delete(m.clearedFields, channelmonitorhistory.FieldLatencyMs)
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) SetPingLatencyMs(i int) {
+ m.ping_latency_ms = &i
+ m.addping_latency_ms = nil
+}
+
+// PingLatencyMs returns the value of the "ping_latency_ms" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) PingLatencyMs() (r int, exists bool) {
+ v := m.ping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPingLatencyMs returns the old "ping_latency_ms" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldPingLatencyMs(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPingLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPingLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPingLatencyMs: %w", err)
+ }
+ return oldValue.PingLatencyMs, nil
+}
+
+// AddPingLatencyMs adds i to the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) AddPingLatencyMs(i int) {
+ if m.addping_latency_ms != nil {
+ *m.addping_latency_ms += i
+ } else {
+ m.addping_latency_ms = &i
+ }
+}
+
+// AddedPingLatencyMs returns the value that was added to the "ping_latency_ms" field in this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedPingLatencyMs() (r int, exists bool) {
+ v := m.addping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ClearPingLatencyMs() {
+ m.ping_latency_ms = nil
+ m.addping_latency_ms = nil
+ m.clearedFields[channelmonitorhistory.FieldPingLatencyMs] = struct{}{}
+}
+
+// PingLatencyMsCleared returns if the "ping_latency_ms" field was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) PingLatencyMsCleared() bool {
+ _, ok := m.clearedFields[channelmonitorhistory.FieldPingLatencyMs]
+ return ok
+}
+
+// ResetPingLatencyMs resets all changes to the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ResetPingLatencyMs() {
+ m.ping_latency_ms = nil
+ m.addping_latency_ms = nil
+ delete(m.clearedFields, channelmonitorhistory.FieldPingLatencyMs)
+}
+
+// SetMessage sets the "message" field.
+func (m *ChannelMonitorHistoryMutation) SetMessage(s string) {
+ m.message = &s
+}
+
+// Message returns the value of the "message" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) Message() (r string, exists bool) {
+ v := m.message
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMessage returns the old "message" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldMessage(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMessage is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMessage requires an ID field in the mutation")
@@ -10445,7 +12071,10 @@ func (m *ChannelMonitorHistoryMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorHistoryMutation) Fields() []string {
- fields := make([]string, 0, 7)
+ fields := make([]string, 0, 8)
+ if m.deleted_at != nil {
+ fields = append(fields, channelmonitorhistory.FieldDeletedAt)
+ }
if m.monitor != nil {
fields = append(fields, channelmonitorhistory.FieldMonitorID)
}
@@ -10475,6 +12104,8 @@ func (m *ChannelMonitorHistoryMutation) Fields() []string {
// schema.
func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) {
switch name {
+ case channelmonitorhistory.FieldDeletedAt:
+ return m.DeletedAt()
case channelmonitorhistory.FieldMonitorID:
return m.MonitorID()
case channelmonitorhistory.FieldModel:
@@ -10498,6 +12129,8 @@ func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) {
// database failed.
func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
+ case channelmonitorhistory.FieldDeletedAt:
+ return m.OldDeletedAt(ctx)
case channelmonitorhistory.FieldMonitorID:
return m.OldMonitorID(ctx)
case channelmonitorhistory.FieldModel:
@@ -10521,6 +12154,13 @@ func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name strin
// type.
func (m *ChannelMonitorHistoryMutation) SetField(name string, value ent.Value) error {
switch name {
+ case channelmonitorhistory.FieldDeletedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDeletedAt(v)
+ return nil
case channelmonitorhistory.FieldMonitorID:
v, ok := value.(int64)
if !ok {
@@ -10627,6 +12267,9 @@ func (m *ChannelMonitorHistoryMutation) AddField(name string, value ent.Value) e
// mutation.
func (m *ChannelMonitorHistoryMutation) ClearedFields() []string {
var fields []string
+ if m.FieldCleared(channelmonitorhistory.FieldDeletedAt) {
+ fields = append(fields, channelmonitorhistory.FieldDeletedAt)
+ }
if m.FieldCleared(channelmonitorhistory.FieldLatencyMs) {
fields = append(fields, channelmonitorhistory.FieldLatencyMs)
}
@@ -10650,6 +12293,9 @@ func (m *ChannelMonitorHistoryMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ClearField(name string) error {
switch name {
+ case channelmonitorhistory.FieldDeletedAt:
+ m.ClearDeletedAt()
+ return nil
case channelmonitorhistory.FieldLatencyMs:
m.ClearLatencyMs()
return nil
@@ -10667,6 +12313,9 @@ func (m *ChannelMonitorHistoryMutation) ClearField(name string) error {
// It returns an error if the field is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ResetField(name string) error {
switch name {
+ case channelmonitorhistory.FieldDeletedAt:
+ m.ResetDeletedAt()
+ return nil
case channelmonitorhistory.FieldMonitorID:
m.ResetMonitorID()
return nil
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index 256b5f2a..adb9a085 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -30,6 +30,9 @@ type AuthIdentityChannel func(*sql.Selector)
// ChannelMonitor is the predicate function for channelmonitor builders.
type ChannelMonitor func(*sql.Selector)
+// ChannelMonitorDailyRollup is the predicate function for channelmonitordailyrollup builders.
+type ChannelMonitorDailyRollup func(*sql.Selector)
+
// ChannelMonitorHistory is the predicate function for channelmonitorhistory builders.
type ChannelMonitorHistory func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 0183f377..25076444 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
@@ -520,6 +521,82 @@ func init() {
channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
// channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
+ channelmonitordailyrollupMixin := schema.ChannelMonitorDailyRollup{}.Mixin()
+ channelmonitordailyrollupMixinHooks0 := channelmonitordailyrollupMixin[0].Hooks()
+ channelmonitordailyrollup.Hooks[0] = channelmonitordailyrollupMixinHooks0[0]
+ channelmonitordailyrollupMixinInters0 := channelmonitordailyrollupMixin[0].Interceptors()
+ channelmonitordailyrollup.Interceptors[0] = channelmonitordailyrollupMixinInters0[0]
+ channelmonitordailyrollupFields := schema.ChannelMonitorDailyRollup{}.Fields()
+ _ = channelmonitordailyrollupFields
+ // channelmonitordailyrollupDescModel is the schema descriptor for model field.
+ channelmonitordailyrollupDescModel := channelmonitordailyrollupFields[1].Descriptor()
+ // channelmonitordailyrollup.ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ channelmonitordailyrollup.ModelValidator = func() func(string) error {
+ validators := channelmonitordailyrollupDescModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(model string) error {
+ for _, fn := range fns {
+ if err := fn(model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitordailyrollupDescTotalChecks is the schema descriptor for total_checks field.
+ channelmonitordailyrollupDescTotalChecks := channelmonitordailyrollupFields[3].Descriptor()
+ // channelmonitordailyrollup.DefaultTotalChecks holds the default value on creation for the total_checks field.
+ channelmonitordailyrollup.DefaultTotalChecks = channelmonitordailyrollupDescTotalChecks.Default.(int)
+ // channelmonitordailyrollupDescOkCount is the schema descriptor for ok_count field.
+ channelmonitordailyrollupDescOkCount := channelmonitordailyrollupFields[4].Descriptor()
+ // channelmonitordailyrollup.DefaultOkCount holds the default value on creation for the ok_count field.
+ channelmonitordailyrollup.DefaultOkCount = channelmonitordailyrollupDescOkCount.Default.(int)
+ // channelmonitordailyrollupDescOperationalCount is the schema descriptor for operational_count field.
+ channelmonitordailyrollupDescOperationalCount := channelmonitordailyrollupFields[5].Descriptor()
+ // channelmonitordailyrollup.DefaultOperationalCount holds the default value on creation for the operational_count field.
+ channelmonitordailyrollup.DefaultOperationalCount = channelmonitordailyrollupDescOperationalCount.Default.(int)
+ // channelmonitordailyrollupDescDegradedCount is the schema descriptor for degraded_count field.
+ channelmonitordailyrollupDescDegradedCount := channelmonitordailyrollupFields[6].Descriptor()
+ // channelmonitordailyrollup.DefaultDegradedCount holds the default value on creation for the degraded_count field.
+ channelmonitordailyrollup.DefaultDegradedCount = channelmonitordailyrollupDescDegradedCount.Default.(int)
+ // channelmonitordailyrollupDescFailedCount is the schema descriptor for failed_count field.
+ channelmonitordailyrollupDescFailedCount := channelmonitordailyrollupFields[7].Descriptor()
+ // channelmonitordailyrollup.DefaultFailedCount holds the default value on creation for the failed_count field.
+ channelmonitordailyrollup.DefaultFailedCount = channelmonitordailyrollupDescFailedCount.Default.(int)
+ // channelmonitordailyrollupDescErrorCount is the schema descriptor for error_count field.
+ channelmonitordailyrollupDescErrorCount := channelmonitordailyrollupFields[8].Descriptor()
+ // channelmonitordailyrollup.DefaultErrorCount holds the default value on creation for the error_count field.
+ channelmonitordailyrollup.DefaultErrorCount = channelmonitordailyrollupDescErrorCount.Default.(int)
+ // channelmonitordailyrollupDescSumLatencyMs is the schema descriptor for sum_latency_ms field.
+ channelmonitordailyrollupDescSumLatencyMs := channelmonitordailyrollupFields[9].Descriptor()
+ // channelmonitordailyrollup.DefaultSumLatencyMs holds the default value on creation for the sum_latency_ms field.
+ channelmonitordailyrollup.DefaultSumLatencyMs = channelmonitordailyrollupDescSumLatencyMs.Default.(int64)
+ // channelmonitordailyrollupDescCountLatency is the schema descriptor for count_latency field.
+ channelmonitordailyrollupDescCountLatency := channelmonitordailyrollupFields[10].Descriptor()
+ // channelmonitordailyrollup.DefaultCountLatency holds the default value on creation for the count_latency field.
+ channelmonitordailyrollup.DefaultCountLatency = channelmonitordailyrollupDescCountLatency.Default.(int)
+ // channelmonitordailyrollupDescSumPingLatencyMs is the schema descriptor for sum_ping_latency_ms field.
+ channelmonitordailyrollupDescSumPingLatencyMs := channelmonitordailyrollupFields[11].Descriptor()
+ // channelmonitordailyrollup.DefaultSumPingLatencyMs holds the default value on creation for the sum_ping_latency_ms field.
+ channelmonitordailyrollup.DefaultSumPingLatencyMs = channelmonitordailyrollupDescSumPingLatencyMs.Default.(int64)
+ // channelmonitordailyrollupDescCountPingLatency is the schema descriptor for count_ping_latency field.
+ channelmonitordailyrollupDescCountPingLatency := channelmonitordailyrollupFields[12].Descriptor()
+ // channelmonitordailyrollup.DefaultCountPingLatency holds the default value on creation for the count_ping_latency field.
+ channelmonitordailyrollup.DefaultCountPingLatency = channelmonitordailyrollupDescCountPingLatency.Default.(int)
+ // channelmonitordailyrollupDescComputedAt is the schema descriptor for computed_at field.
+ channelmonitordailyrollupDescComputedAt := channelmonitordailyrollupFields[13].Descriptor()
+ // channelmonitordailyrollup.DefaultComputedAt holds the default value on creation for the computed_at field.
+ channelmonitordailyrollup.DefaultComputedAt = channelmonitordailyrollupDescComputedAt.Default.(func() time.Time)
+ // channelmonitordailyrollup.UpdateDefaultComputedAt holds the default value on update for the computed_at field.
+ channelmonitordailyrollup.UpdateDefaultComputedAt = channelmonitordailyrollupDescComputedAt.UpdateDefault.(func() time.Time)
+ channelmonitorhistoryMixin := schema.ChannelMonitorHistory{}.Mixin()
+ channelmonitorhistoryMixinHooks0 := channelmonitorhistoryMixin[0].Hooks()
+ channelmonitorhistory.Hooks[0] = channelmonitorhistoryMixinHooks0[0]
+ channelmonitorhistoryMixinInters0 := channelmonitorhistoryMixin[0].Interceptors()
+ channelmonitorhistory.Interceptors[0] = channelmonitorhistoryMixinInters0[0]
channelmonitorhistoryFields := schema.ChannelMonitorHistory{}.Fields()
_ = channelmonitorhistoryFields
// channelmonitorhistoryDescModel is the schema descriptor for model field.
diff --git a/backend/ent/schema/channel_monitor.go b/backend/ent/schema/channel_monitor.go
index 3fa17319..f6a6578d 100644
--- a/backend/ent/schema/channel_monitor.go
+++ b/backend/ent/schema/channel_monitor.go
@@ -69,6 +69,8 @@ func (ChannelMonitor) Edges() []ent.Edge {
return []ent.Edge{
edge.To("history", ChannelMonitorHistory.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("daily_rollups", ChannelMonitorDailyRollup.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
}
}
diff --git a/backend/ent/schema/channel_monitor_daily_rollup.go b/backend/ent/schema/channel_monitor_daily_rollup.go
new file mode 100644
index 00000000..574a28d9
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_daily_rollup.go
@@ -0,0 +1,73 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+)
+
+// ChannelMonitorDailyRollup 按 (monitor_id, model, bucket_date) 维度聚合的渠道监控日统计。
+// 每天的明细被收敛为一行(保留 status 分布 + 延迟和),用于 7d/15d/30d 窗口的可用率
+// 加权计算(avg_latency = sum_latency_ms / count_latency;availability = ok_count / total_checks)。
+type ChannelMonitorDailyRollup struct {
+ ent.Schema
+}
+
+func (ChannelMonitorDailyRollup) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_daily_rollups"},
+ }
+}
+
+func (ChannelMonitorDailyRollup) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.SoftDeleteMixin{},
+ }
+}
+
+func (ChannelMonitorDailyRollup) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("monitor_id"),
+ field.String("model").
+ NotEmpty().
+ MaxLen(200),
+ field.Time("bucket_date").
+ SchemaType(map[string]string{dialect.Postgres: "date"}),
+ field.Int("total_checks").Default(0),
+ field.Int("ok_count").Default(0),
+ field.Int("operational_count").Default(0),
+ field.Int("degraded_count").Default(0),
+ field.Int("failed_count").Default(0),
+ field.Int("error_count").Default(0),
+ field.Int64("sum_latency_ms").Default(0),
+ field.Int("count_latency").Default(0),
+ field.Int64("sum_ping_latency_ms").Default(0),
+ field.Int("count_ping_latency").Default(0),
+ field.Time("computed_at").Default(time.Now).UpdateDefault(time.Now),
+ }
+}
+
+func (ChannelMonitorDailyRollup) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitor", ChannelMonitor.Type).
+ Ref("daily_rollups").
+ Field("monitor_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (ChannelMonitorDailyRollup) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("monitor_id", "model", "bucket_date").Unique(),
+ index.Fields("bucket_date"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_history.go b/backend/ent/schema/channel_monitor_history.go
index 50352016..ec54b34f 100644
--- a/backend/ent/schema/channel_monitor_history.go
+++ b/backend/ent/schema/channel_monitor_history.go
@@ -9,10 +9,13 @@ import (
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
)
// ChannelMonitorHistory holds the schema definition for the ChannelMonitorHistory entity.
-// 渠道监控历史:每次检测每个模型一行记录,由调度器写入,定期清理 30 天前的旧数据。
+// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天的数据被聚合到
+// channel_monitor_daily_rollups 后软删(deleted_at),由后续懒清理任务物理移除。
type ChannelMonitorHistory struct {
ent.Schema
}
@@ -23,6 +26,12 @@ func (ChannelMonitorHistory) Annotations() []schema.Annotation {
}
}
+func (ChannelMonitorHistory) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.SoftDeleteMixin{},
+ }
+}
+
func (ChannelMonitorHistory) Fields() []ent.Field {
return []ent.Field{
field.Int64("monitor_id"),
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index f937270f..0e65a940 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -30,6 +30,8 @@ type Tx struct {
AuthIdentityChannel *AuthIdentityChannelClient
// ChannelMonitor is the client for interacting with the ChannelMonitor builders.
ChannelMonitor *ChannelMonitorClient
+ // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
+ ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
// ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
ChannelMonitorHistory *ChannelMonitorHistoryClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
@@ -217,6 +219,7 @@ func (tx *Tx) init() {
tx.AuthIdentity = NewAuthIdentityClient(tx.config)
tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ChannelMonitor = NewChannelMonitorClient(tx.config)
+ tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config)
tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
index cf5e1a93..badbdbca 100644
--- a/backend/internal/repository/channel_monitor_repo.go
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -9,6 +9,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
@@ -246,6 +247,7 @@ func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monit
model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = $1
+ AND deleted_at IS NULL
ORDER BY model, checked_at DESC
`
rows, err := r.db.QueryContext(ctx, q, monitorID)
@@ -280,23 +282,48 @@ func assignNullInt(dst **int, n sql.NullInt64) {
// ComputeAvailability 计算指定窗口内每个模型的可用率与平均延迟。
// "可用" = status IN (operational, degraded)。
+//
+// 数据来源:明细表只保留 1 天;窗口前其余天数走聚合表。
+// - raw = 今天(CURRENT_DATE 起)的未软删明细,按 model 累加
+// - rollup = [CURRENT_DATE - windowDays, CURRENT_DATE) 区间的聚合行
+//
+// 总窗口为 "今天 + 过去 windowDays 天",比 windowDays 字面值大 1 天,但因为聚合
+// 是按整 UTC 日切的,这是聚合化无法避免的精度损失,且偏宽不偏窄(数据更全)。
func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) {
if windowDays <= 0 {
windowDays = 7
}
const q = `
- SELECT
- model,
- COUNT(*) AS total_checks,
- COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_checks,
- AVG(latency_ms) FILTER (WHERE latency_ms IS NOT NULL) AS avg_latency_ms
- FROM channel_monitor_histories
- WHERE monitor_id = $1
- AND checked_at >= $2
+ WITH raw AS (
+ SELECT model,
+ COUNT(*) AS total_checks,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
+ COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
+ COUNT(latency_ms) AS count_latency
+ FROM channel_monitor_histories
+ WHERE monitor_id = $1
+ AND deleted_at IS NULL
+ AND checked_at >= CURRENT_DATE
+ GROUP BY model
+ ),
+ rollup AS (
+ SELECT model, total_checks, ok_count, sum_latency_ms, count_latency
+ FROM channel_monitor_daily_rollups
+ WHERE monitor_id = $1
+ AND deleted_at IS NULL
+ AND bucket_date >= (CURRENT_DATE - $2::int)
+ AND bucket_date < CURRENT_DATE
+ )
+ SELECT model,
+ SUM(total_checks) AS total,
+ SUM(ok_count) AS ok,
+ CASE WHEN SUM(count_latency) > 0
+ THEN SUM(sum_latency_ms)::float8 / SUM(count_latency)
+ ELSE NULL END AS avg_latency_ms
+ FROM (SELECT * FROM raw UNION ALL SELECT * FROM rollup) combined
GROUP BY model
`
- from := time.Now().AddDate(0, 0, -windowDays)
- rows, err := r.db.QueryContext(ctx, q, monitorID, from)
+ rows, err := r.db.QueryContext(ctx, q, monitorID, windowDays)
if err != nil {
return nil, fmt.Errorf("query availability: %w", err)
}
@@ -349,6 +376,7 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context,
monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = ANY($1)
+ AND deleted_at IS NULL
ORDER BY monitor_id, model, checked_at DESC
`
rows, err := r.db.QueryContext(ctx, q, pq.Array(ids))
@@ -409,6 +437,7 @@ func (r *channelMonitorRepository) ListRecentHistoryForMonitors(
FROM channel_monitor_histories h
JOIN targets t
ON t.monitor_id = h.monitor_id AND t.model = h.model
+ WHERE h.deleted_at IS NULL
)
SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
FROM ranked
@@ -476,6 +505,7 @@ func clampTimelineLimit(n int) int {
}
// ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。
+// 与单 monitor 版本同构:明细只覆盖今天,更早走聚合表 UNION 合并。
func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) {
out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids))
if len(ids) == 0 {
@@ -485,19 +515,38 @@ func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Co
windowDays = 7
}
const q = `
- SELECT
- monitor_id,
- model,
- COUNT(*) AS total_checks,
- COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_checks,
- AVG(latency_ms) FILTER (WHERE latency_ms IS NOT NULL) AS avg_latency_ms
- FROM channel_monitor_histories
- WHERE monitor_id = ANY($1)
- AND checked_at >= $2
+ WITH raw AS (
+ SELECT monitor_id,
+ model,
+ COUNT(*) AS total_checks,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
+ COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
+ COUNT(latency_ms) AS count_latency
+ FROM channel_monitor_histories
+ WHERE monitor_id = ANY($1)
+ AND deleted_at IS NULL
+ AND checked_at >= CURRENT_DATE
+ GROUP BY monitor_id, model
+ ),
+ rollup AS (
+ SELECT monitor_id, model, total_checks, ok_count, sum_latency_ms, count_latency
+ FROM channel_monitor_daily_rollups
+ WHERE monitor_id = ANY($1)
+ AND deleted_at IS NULL
+ AND bucket_date >= (CURRENT_DATE - $2::int)
+ AND bucket_date < CURRENT_DATE
+ )
+ SELECT monitor_id,
+ model,
+ SUM(total_checks) AS total,
+ SUM(ok_count) AS ok,
+ CASE WHEN SUM(count_latency) > 0
+ THEN SUM(sum_latency_ms)::float8 / SUM(count_latency)
+ ELSE NULL END AS avg_latency_ms
+ FROM (SELECT * FROM raw UNION ALL SELECT * FROM rollup) combined
GROUP BY monitor_id, model
`
- from := time.Now().AddDate(0, 0, -windowDays)
- rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), from)
+ rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), windowDays)
if err != nil {
return nil, fmt.Errorf("query availability batch: %w", err)
}
@@ -521,6 +570,116 @@ func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Co
return out, nil
}
+// ---------- 聚合维护 ----------
+
+// UpsertDailyRollupsFor 把 targetDate 当天([targetDate, targetDate+1d))未软删的明细
+// 按 (monitor_id, model, bucket_date) 聚合写入 channel_monitor_daily_rollups。
+// - 用 ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE 实现幂等回填,
+// 重复执行只会用最新统计覆盖;
+// - 同时把 deleted_at 重置为 NULL,避免历史误删后聚合行被持续过滤掉;
+// - $1::date 让 PG 把入参按会话时区(默认部署下为 UTC,需确认 TimeZone 设置)截断为日期,调用方不需要预处理 targetDate。
+func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error) {
+ const q = `
+ INSERT INTO channel_monitor_daily_rollups (
+ monitor_id, model, bucket_date,
+ total_checks, ok_count,
+ operational_count, degraded_count, failed_count, error_count,
+ sum_latency_ms, count_latency,
+ sum_ping_latency_ms, count_ping_latency,
+ computed_at
+ )
+ SELECT
+ monitor_id,
+ model,
+ $1::date AS bucket_date,
+ COUNT(*) AS total_checks,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
+ COUNT(*) FILTER (WHERE status = 'operational') AS operational_count,
+ COUNT(*) FILTER (WHERE status = 'degraded') AS degraded_count,
+ COUNT(*) FILTER (WHERE status = 'failed') AS failed_count,
+ COUNT(*) FILTER (WHERE status = 'error') AS error_count,
+ COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
+ COUNT(latency_ms) AS count_latency,
+ COALESCE(SUM(ping_latency_ms) FILTER (WHERE ping_latency_ms IS NOT NULL), 0) AS sum_ping_latency_ms,
+ COUNT(ping_latency_ms) AS count_ping_latency,
+ NOW()
+ FROM channel_monitor_histories
+ WHERE deleted_at IS NULL
+ AND checked_at >= $1::date
+ AND checked_at < ($1::date + INTERVAL '1 day')
+ GROUP BY monitor_id, model
+ ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET
+ total_checks = EXCLUDED.total_checks,
+ ok_count = EXCLUDED.ok_count,
+ operational_count = EXCLUDED.operational_count,
+ degraded_count = EXCLUDED.degraded_count,
+ failed_count = EXCLUDED.failed_count,
+ error_count = EXCLUDED.error_count,
+ sum_latency_ms = EXCLUDED.sum_latency_ms,
+ count_latency = EXCLUDED.count_latency,
+ sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms,
+ count_ping_latency = EXCLUDED.count_ping_latency,
+ computed_at = NOW(),
+ deleted_at = NULL
+ `
+ res, err := r.db.ExecContext(ctx, q, targetDate)
+ if err != nil {
+ return 0, fmt.Errorf("upsert daily rollups for %s: %w", targetDate.Format("2006-01-02"), err)
+ }
+ n, err := res.RowsAffected()
+ if err != nil {
+ return 0, fmt.Errorf("rows affected (upsert rollups): %w", err)
+ }
+ return n, nil
+}
+
+// DeleteRollupsBefore 软删 bucket_date < beforeDate 的聚合行。
+// 走 ent client,利用 SoftDeleteMixin 把 DELETE 自动改写为 UPDATE deleted_at = NOW()。
+func (r *channelMonitorRepository) DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error) {
+ client := clientFromContext(ctx, r.client)
+ n, err := client.ChannelMonitorDailyRollup.Delete().
+ Where(channelmonitordailyrollup.BucketDateLT(beforeDate)).
+ Exec(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("delete rollups before: %w", err)
+ }
+ return int64(n), nil
+}
+
+// LoadAggregationWatermark 读 watermark 表(id=1)。
+// watermark 表不是 ent schema(只有一行),直接走原生 SQL。
+// - 行不存在或 last_aggregated_date IS NULL:返回 (nil, nil),由调用方决定首次回填策略
+func (r *channelMonitorRepository) LoadAggregationWatermark(ctx context.Context) (*time.Time, error) {
+ const q = `SELECT last_aggregated_date FROM channel_monitor_aggregation_watermark WHERE id = 1`
+ var t sql.NullTime
+ if err := r.db.QueryRowContext(ctx, q).Scan(&t); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("load aggregation watermark: %w", err)
+ }
+ if !t.Valid {
+ return nil, nil
+ }
+ return &t.Time, nil
+}
+
+// UpdateAggregationWatermark 更新 watermark(UPSERT 到 id=1)。
+// $1::date 让 PG 按会话时区(默认部署下为 UTC,需确认 TimeZone 设置)把入参截断为日期,与 last_aggregated_date 列的 DATE 类型一致。
+func (r *channelMonitorRepository) UpdateAggregationWatermark(ctx context.Context, date time.Time) error {
+ const q = `
+ INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
+ VALUES (1, $1::date, NOW())
+ ON CONFLICT (id) DO UPDATE SET
+ last_aggregated_date = EXCLUDED.last_aggregated_date,
+ updated_at = NOW()
+ `
+ if _, err := r.db.ExecContext(ctx, q, date); err != nil {
+ return fmt.Errorf("update aggregation watermark: %w", err)
+ }
+ return nil
+}
+
// ---------- helpers ----------
func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor {
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
index 7255e4be..b61f3bdd 100644
--- a/backend/internal/service/channel_monitor_const.go
+++ b/backend/internal/service/channel_monitor_const.go
@@ -15,8 +15,16 @@ const (
monitorPingTimeout = 8 * time.Second
// monitorDegradedThreshold 主请求成功但耗时超过该阈值视为 degraded。
monitorDegradedThreshold = 6 * time.Second
- // monitorHistoryRetentionDays 历史保留天数(每天清理一次)。
- monitorHistoryRetentionDays = 30
+ // monitorHistoryRetentionDays 明细历史保留天数。
+ // 明细只保留 1 天,超出由 SoftDeleteMixin 软删;
+ // 维护任务每天凌晨跑(由 OpsCleanupService 统一调度)。
+ monitorHistoryRetentionDays = 1
+ // monitorRollupRetentionDays 日聚合保留天数。
+ // 日聚合行由 RunDailyMaintenance 在超过该窗口后软删。
+ monitorRollupRetentionDays = 30
+ // monitorMaintenanceMaxDaysPerRun 单次维护任务最多聚合的天数。
+ // 用于限制首次上线回填(30 天)+ 少量余量,避免长事务。
+ monitorMaintenanceMaxDaysPerRun = 35
// monitorWorkerConcurrency 调度器并发执行的监控数(pond 池容量)。
monitorWorkerConcurrency = 5
// monitorTickerInterval 调度器扫描"到期监控"的间隔。
@@ -55,11 +63,6 @@ const (
monitorAvailability15Days = 15
monitorAvailability30Days = 30
- // monitorCleanupCheckInterval 历史清理调度器的检查频率(每小时检查"是否到 03:00")。
- monitorCleanupCheckInterval = time.Hour
- // monitorCleanupHour 凌晨 3 点执行历史清理。
- monitorCleanupHour = 3
-
// MonitorHistoryDefaultLimit 历史查询默认返回条数(handler 层共享)。
MonitorHistoryDefaultLimit = 100
// MonitorHistoryMaxLimit 历史查询最大返回条数(handler 层共享)。
@@ -82,10 +85,6 @@ const (
monitorListDueTimeout = 10 * time.Second
// monitorRunOneBuffer runOne 的总超时缓冲(除请求超时与 ping 超时外的额外裕量)。
monitorRunOneBuffer = 10 * time.Second
- // monitorCleanupTimeout 历史清理任务的总超时。
- monitorCleanupTimeout = 30 * time.Second
- // monitorCleanupDayLayout 历史清理用于"今日是否已跑过"判定的日期格式。
- monitorCleanupDayLayout = "2006-01-02"
// monitorIdleConnTimeout HTTP transport 空闲连接关闭超时。
monitorIdleConnTimeout = 30 * time.Second
diff --git a/backend/internal/service/channel_monitor_runner.go b/backend/internal/service/channel_monitor_runner.go
index 4655e6df..21dca8ab 100644
--- a/backend/internal/service/channel_monitor_runner.go
+++ b/backend/internal/service/channel_monitor_runner.go
@@ -14,10 +14,10 @@ import (
// 职责:
// - 每 monitorTickerInterval 扫描一次"到期需要检测"的监控
// - 通过 pond 池(容量 monitorWorkerConcurrency)异步执行检测
-// - 每小时检查一次时钟,到 monitorCleanupHour 点时执行历史清理
// - Stop 时优雅关闭:池 drain + ticker.Stop + wg.Wait
//
-// 不引入 cron 库;清理调度通过"每小时检查时间"实现,足够 MVP。
+// 历史清理与日聚合维护不再由 runner 负责,由 OpsCleanupService 的统一 cron
+// 在凌晨触发 ChannelMonitorService.RunDailyMaintenance(复用 leader lock + heartbeat)。
//
// 定时任务维护:删除/创建/编辑 monitor 无需显式 reload,每个 tick 都会重新查 DB
// (ListEnabled + listDueForCheck),新 monitor 的 LastCheckedAt 为 nil 天然立即到期,
@@ -35,10 +35,6 @@ type ChannelMonitorRunner struct {
// 防止单次检测耗时 > interval 时同一 monitor 被并发执行。
inFlight map[int64]struct{}
inFlightMu sync.Mutex
-
- // 清理状态:lastCleanupDay 记录上次清理的"年-月-日",避免同一天重复跑。
- lastCleanupDay string
- cleanupMu sync.Mutex
}
// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用。
@@ -52,7 +48,7 @@ func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *Setting
}
}
-// Start 启动 ticker + worker pool + cleanup loop。
+// Start 启动 ticker + worker pool。
// 调用方需保证只调一次(wire ProvideChannelMonitorRunner 内只调一次)。
func (r *ChannelMonitorRunner) Start() {
if r == nil || r.svc == nil {
@@ -61,12 +57,11 @@ func (r *ChannelMonitorRunner) Start() {
// 容量 5 的 pond 池:超出时调用方等待,避免调度堆积无限增长。
r.pool = pond.NewPool(monitorWorkerConcurrency)
- r.wg.Add(2)
+ r.wg.Add(1)
go r.dueCheckLoop()
- go r.cleanupLoop()
}
-// Stop 优雅停止:close stopCh -> 等待两个 loop 退出 -> 池 drain。
+// Stop 优雅停止:close stopCh -> 等待 loop 退出 -> 池 drain。
func (r *ChannelMonitorRunner) Stop() {
if r == nil {
return
@@ -176,45 +171,3 @@ func (r *ChannelMonitorRunner) runOne(id int64, name string) {
"monitor_id", id, "name", name, "error", err)
}
}
-
-// cleanupLoop 每小时检查当前时间,到 monitorCleanupHour 点(且当天还没清理过)则跑一次清理。
-// 启动时立即检查一次,避免长时间运行才跑首次清理。
-func (r *ChannelMonitorRunner) cleanupLoop() {
- defer r.wg.Done()
-
- ticker := time.NewTicker(monitorCleanupCheckInterval)
- defer ticker.Stop()
-
- r.maybeRunCleanup()
- for {
- select {
- case <-r.stopCh:
- return
- case <-ticker.C:
- r.maybeRunCleanup()
- }
- }
-}
-
-// maybeRunCleanup 如果当前小时是 monitorCleanupHour 且当天未跑过,则执行清理。
-func (r *ChannelMonitorRunner) maybeRunCleanup() {
- now := time.Now()
- if now.Hour() != monitorCleanupHour {
- return
- }
- day := now.Format(monitorCleanupDayLayout)
-
- r.cleanupMu.Lock()
- if r.lastCleanupDay == day {
- r.cleanupMu.Unlock()
- return
- }
- r.lastCleanupDay = day
- r.cleanupMu.Unlock()
-
- ctx, cancel := context.WithTimeout(context.Background(), monitorCleanupTimeout)
- defer cancel()
- if err := r.svc.cleanupOldHistory(ctx); err != nil {
- slog.Warn("channel_monitor: cleanup history failed", "error", err)
- }
-}
diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go
index 957ace15..144c66a0 100644
--- a/backend/internal/service/channel_monitor_service.go
+++ b/backend/internal/service/channel_monitor_service.go
@@ -41,6 +41,20 @@ type ChannelMonitorRepository interface {
// ListRecentHistoryForMonitors 批量取多个 monitor 各自主模型(primaryModels[monitorID])最近 perMonitorLimit 条历史。
// 返回的 entry 已按 checked_at DESC 排序(最新在前),不含 message 字段。
ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error)
+
+ // ---------- 聚合维护(OpsCleanupService 调用) ----------
+
+ // UpsertDailyRollupsFor 把 targetDate 当天的明细按 (monitor_id, model, bucket_date)
+ // 聚合到 channel_monitor_daily_rollups。targetDate 会被截断到日期;
+ // 用 ON CONFLICT DO UPDATE 实现幂等回填,返回 upsert 影响的行数。
+ UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error)
+ // DeleteRollupsBefore 软删 bucket_date < beforeDate 的聚合行,返回删除行数。
+ DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error)
+ // LoadAggregationWatermark 读 watermark(id=1)。
+ // 返回 nil 表示从未聚合过;watermark 表本身预期已存在单行(migration 110 写入)。
+ LoadAggregationWatermark(ctx context.Context) (*time.Time, error)
+ // UpdateAggregationWatermark 写 watermark(UPSERT 到 id=1)。
+ UpdateAggregationWatermark(ctx context.Context, date time.Time) error
}
// ChannelMonitorService 渠道监控管理服务。
@@ -300,9 +314,10 @@ func (s *ChannelMonitorService) listDueForCheck(ctx context.Context) ([]*Channel
return due, nil
}
-// cleanupOldHistory 删除 monitorHistoryRetentionDays 天之前的历史记录。
+// cleanupOldHistory 删除 monitorHistoryRetentionDays 天之前的明细历史记录。
+// 由 RunDailyMaintenance 调用;SoftDeleteMixin 自动把 DELETE 改为 UPDATE deleted_at。
func (s *ChannelMonitorService) cleanupOldHistory(ctx context.Context) error {
- before := time.Now().AddDate(0, 0, -monitorHistoryRetentionDays)
+ before := time.Now().UTC().AddDate(0, 0, -monitorHistoryRetentionDays)
deleted, err := s.repo.DeleteHistoryBefore(ctx, before)
if err != nil {
return fmt.Errorf("delete history before %s: %w", before.Format(time.RFC3339), err)
@@ -314,6 +329,94 @@ func (s *ChannelMonitorService) cleanupOldHistory(ctx context.Context) error {
return nil
}
+// RunDailyMaintenance is the daily maintenance task: it aggregates detail
+// rows for any not-yet-aggregated days before today, then prunes expired
+// detail history and expired daily rollups.
+// It is triggered by the OpsCleanupService cron (shared schedule and
+// leader lock).
+//
+// Idempotency:
+//   - the watermark guarantees already-aggregated dates are not reprocessed;
+//   - UpsertDailyRollupsFor uses ON CONFLICT DO UPDATE internally, so
+//     re-running the same day yields the same result.
+//
+// Each step only logs slog.Warn on failure; the function always returns nil
+// so later steps keep running (same style as OpsCleanupService.runCleanupOnce).
+func (s *ChannelMonitorService) RunDailyMaintenance(ctx context.Context) error {
+	now := time.Now().UTC()
+	// Truncate to the UTC day boundary; all maintenance math is day-granular.
+	today := now.Truncate(24 * time.Hour)
+
+	if err := s.runDailyAggregation(ctx, today); err != nil {
+		slog.Warn("channel_monitor: maintenance step failed",
+			"step", "aggregate", "error", err)
+	}
+	if err := s.cleanupOldHistory(ctx); err != nil {
+		slog.Warn("channel_monitor: maintenance step failed",
+			"step", "prune_history", "error", err)
+	}
+	if err := s.cleanupOldRollups(ctx, today); err != nil {
+		slog.Warn("channel_monitor: maintenance step failed",
+			"step", "prune_rollups", "error", err)
+	}
+	return nil
+}
+
+// runDailyAggregation aggregates from watermark+1 up to yesterday (UTC).
+// First run (nil watermark): backfills starting from
+// today - monitorRollupRetentionDays.
+// At most monitorMaintenanceMaxDaysPerRun days are aggregated per invocation
+// to avoid long transactions; the watermark lets the next run resume.
+func (s *ChannelMonitorService) runDailyAggregation(ctx context.Context, today time.Time) error {
+	watermark, err := s.repo.LoadAggregationWatermark(ctx)
+	if err != nil {
+		return fmt.Errorf("load watermark: %w", err)
+	}
+
+	start := s.resolveAggregationStart(watermark, today)
+	if !start.Before(today) {
+		return nil // no dates left to aggregate
+	}
+
+	iterations := 0
+	for d := start; d.Before(today); d = d.Add(24 * time.Hour) {
+		if iterations >= monitorMaintenanceMaxDaysPerRun {
+			// Per-run cap reached: stop; resume from d on the next run.
+			slog.Info("channel_monitor: maintenance aggregation capped",
+				"max_days", monitorMaintenanceMaxDaysPerRun,
+				"next_resume", d.Format("2006-01-02"))
+			break
+		}
+		affected, upErr := s.repo.UpsertDailyRollupsFor(ctx, d)
+		if upErr != nil {
+			return fmt.Errorf("upsert rollups for %s: %w", d.Format("2006-01-02"), upErr)
+		}
+		// Advance the watermark only after the day's upsert succeeds, so a
+		// failure in between re-aggregates that day (upsert is idempotent).
+		if err := s.repo.UpdateAggregationWatermark(ctx, d); err != nil {
+			return fmt.Errorf("update watermark to %s: %w", d.Format("2006-01-02"), err)
+		}
+		slog.Info("channel_monitor: rollups upserted",
+			"date", d.Format("2006-01-02"), "affected_rows", affected)
+		iterations++
+	}
+	return nil
+}
+
+// resolveAggregationStart computes where this aggregation pass starts:
+//   - watermark == nil: today - monitorRollupRetentionDays (first-run
+//     backfill, bounded by the rollup retention window)
+//   - watermark != nil: *watermark + 1 day, normalized to a UTC day boundary
+func (s *ChannelMonitorService) resolveAggregationStart(watermark *time.Time, today time.Time) time.Time {
+	if watermark == nil {
+		return today.AddDate(0, 0, -monitorRollupRetentionDays)
+	}
+	return watermark.UTC().Truncate(24 * time.Hour).Add(24 * time.Hour)
+}
+
+// cleanupOldRollups deletes daily rollup rows with
+// bucket_date < today - monitorRollupRetentionDays (delete semantics —
+// soft vs. hard — are the repository's concern; see DeleteRollupsBefore).
+func (s *ChannelMonitorService) cleanupOldRollups(ctx context.Context, today time.Time) error {
+	cutoff := today.AddDate(0, 0, -monitorRollupRetentionDays)
+	deleted, err := s.repo.DeleteRollupsBefore(ctx, cutoff)
+	if err != nil {
+		return fmt.Errorf("delete rollups before %s: %w", cutoff.Format("2006-01-02"), err)
+	}
+	if deleted > 0 {
+		slog.Info("channel_monitor: rollups cleanup",
+			"deleted_rows", deleted, "before", cutoff.Format("2006-01-02"))
+	}
+	return nil
+}
+
// ---------- helpers ----------
// decryptInPlace 把 ChannelMonitor.APIKey 从密文解密为明文。
diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go
index 1cae6fe5..08a10a02 100644
--- a/backend/internal/service/ops_cleanup_service.go
+++ b/backend/internal/service/ops_cleanup_service.go
@@ -36,11 +36,15 @@ return 0
// - Scheduling: 5-field cron spec (minute hour dom month dow).
// - Multi-instance: best-effort Redis leader lock so only one node runs cleanup.
// - Safety: deletes in batches to avoid long transactions.
+//
+// 附带:在 runCleanupOnce 末尾调用 ChannelMonitorService.RunDailyMaintenance,
+// 统一共享 cron schedule + leader lock + heartbeat,避免再引一套调度。
type OpsCleanupService struct {
- opsRepo OpsRepository
- db *sql.DB
- redisClient *redis.Client
- cfg *config.Config
+ opsRepo OpsRepository
+ db *sql.DB
+ redisClient *redis.Client
+ cfg *config.Config
+ channelMonitorSvc *ChannelMonitorService
instanceID string
@@ -57,13 +61,15 @@ func NewOpsCleanupService(
db *sql.DB,
redisClient *redis.Client,
cfg *config.Config,
+ channelMonitorSvc *ChannelMonitorService,
) *OpsCleanupService {
return &OpsCleanupService{
- opsRepo: opsRepo,
- db: db,
- redisClient: redisClient,
- cfg: cfg,
- instanceID: uuid.NewString(),
+ opsRepo: opsRepo,
+ db: db,
+ redisClient: redisClient,
+ cfg: cfg,
+ channelMonitorSvc: channelMonitorSvc,
+ instanceID: uuid.NewString(),
}
}
@@ -248,6 +254,15 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
out.dailyPreagg = n
}
+ // Channel monitor 每日维护(聚合昨日明细 + 软删过期明细/聚合)。
+ // 失败只记日志,不影响 ops 清理的成功状态(与 ops 各步骤风格一致);
+ // 维护本身已经把每步错误打到 slog,heartbeat result 不再分项记录。
+ if s.channelMonitorSvc != nil {
+ if err := s.channelMonitorSvc.RunDailyMaintenance(ctx); err != nil {
+ logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] channel monitor maintenance failed: %v", err)
+ }
+ }
+
return out, nil
}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 5d8d88d2..1482d650 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -262,13 +262,16 @@ func ProvideOpsAlertEvaluatorService(
}
// ProvideOpsCleanupService creates and starts OpsCleanupService (cron scheduled).
+// channelMonitorSvc 让维护任务(聚合 + 历史/聚合软删)跟随 ops 清理 cron 一起跑,
+// 共享 leader lock + heartbeat。
func ProvideOpsCleanupService(
opsRepo OpsRepository,
db *sql.DB,
redisClient *redis.Client,
cfg *config.Config,
+ channelMonitorSvc *ChannelMonitorService,
) *OpsCleanupService {
- svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg)
+ svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg, channelMonitorSvc)
svc.Start()
return svc
}
diff --git a/backend/migrations/126_add_channel_monitor_aggregation.sql b/backend/migrations/126_add_channel_monitor_aggregation.sql
new file mode 100644
index 00000000..e643763c
--- /dev/null
+++ b/backend/migrations/126_add_channel_monitor_aggregation.sql
@@ -0,0 +1,60 @@
+-- Migration: 126_add_channel_monitor_aggregation
+-- Channel monitor daily aggregation: roll up channel_monitor_histories detail
+-- rows by day. Details are kept for only 1 day; rollups are kept for 30 days.
+-- Both the detail and rollup tables use soft delete (deleted_at), pruned
+-- nightly by the ops cleanup task together with the ops monitoring cleanup
+-- (shared cron).
+--
+-- Design notes:
+-- - channel_monitor_histories gains a deleted_at soft-delete column (the
+--   SoftDeleteMixin global hook rewrites DELETE into UPDATE deleted_at = NOW()).
+-- - channel_monitor_daily_rollups is unique per (monitor_id, model,
+--   bucket_date); ON CONFLICT DO UPDATE gives idempotent backfill. The status
+--   distribution and latency numerators/denominators are all retained, so
+--   weighted availability and averages can later be derived over any window.
+-- - The watermark table holds a single row (id=1) recording the most recent
+--   aggregated date, so a restart does not rescan the whole table.
+-- - The (bucket_date) index on the rollup table serves the cleanup task's
+--   DELETE WHERE bucket_date < cutoff.
+
+-- 1) Add the soft-delete column to the detail history table
+ALTER TABLE channel_monitor_histories
+    ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ;
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_deleted_at
+    ON channel_monitor_histories (deleted_at);
+
+-- 2) Create the daily rollup table
+CREATE TABLE IF NOT EXISTS channel_monitor_daily_rollups (
+    id BIGSERIAL PRIMARY KEY,
+    monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE,
+    model VARCHAR(200) NOT NULL,
+    bucket_date DATE NOT NULL,
+    total_checks INT NOT NULL DEFAULT 0,
+    ok_count INT NOT NULL DEFAULT 0,
+    operational_count INT NOT NULL DEFAULT 0,
+    degraded_count INT NOT NULL DEFAULT 0,
+    failed_count INT NOT NULL DEFAULT 0,
+    error_count INT NOT NULL DEFAULT 0,
+    sum_latency_ms BIGINT NOT NULL DEFAULT 0,
+    count_latency INT NOT NULL DEFAULT 0,
+    sum_ping_latency_ms BIGINT NOT NULL DEFAULT 0,
+    count_ping_latency INT NOT NULL DEFAULT 0,
+    computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    deleted_at TIMESTAMPTZ
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_unique
+    ON channel_monitor_daily_rollups (monitor_id, model, bucket_date);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_bucket
+    ON channel_monitor_daily_rollups (bucket_date);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_deleted_at
+    ON channel_monitor_daily_rollups (deleted_at);
+
+-- 3) Create the watermark table (single row: id=1)
+CREATE TABLE IF NOT EXISTS channel_monitor_aggregation_watermark (
+    id INT PRIMARY KEY DEFAULT 1,
+    last_aggregated_date DATE,
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    CONSTRAINT channel_monitor_aggregation_watermark_singleton CHECK (id = 1)
+);
+
+-- Seed the singleton row so first reads see (id=1, NULL) instead of no rows.
+INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
+VALUES (1, NULL, NOW())
+ON CONFLICT (id) DO NOTHING;
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
index 56a06a9f..836ec079 100644
--- a/frontend/src/components/admin/monitor/MonitorFormDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -113,6 +113,7 @@
:loading="myKeysLoading"
:keys="myActiveKeys"
:provider="form.provider"
+ :user-group-rates="userGroupRates"
@close="showKeyPicker = false"
@pick="pickMyKey"
/>
@@ -125,6 +126,7 @@ import { useAppStore } from '@/stores/app'
import { extractApiErrorMessage } from '@/utils/apiError'
import { adminAPI } from '@/api/admin'
import { keysAPI } from '@/api/keys'
+import { userGroupsAPI } from '@/api/groups'
import type {
ChannelMonitor,
CreateParams,
@@ -175,6 +177,7 @@ const submitting = ref(false)
const showKeyPicker = ref(false)
const myKeysLoading = ref(false)
const myActiveKeys = ref([])
+const userGroupRates = ref<Record<string, number>>({})
interface MonitorForm {
name: string
@@ -263,7 +266,10 @@ async function openMyKeyPicker() {
if (myActiveKeys.value.length > 0) return
myKeysLoading.value = true
try {
- const res = await keysAPI.list(1, 100, { status: 'active' })
+ const [res, rates] = await Promise.all([
+ keysAPI.list(1, 100, { status: 'active' }),
+ userGroupsAPI.getUserGroupRates(),
+ ])
const items = res.items || []
const now = Date.now()
myActiveKeys.value = items.filter(k => {
@@ -271,6 +277,7 @@ async function openMyKeyPicker() {
if (!k.expires_at) return true
return new Date(k.expires_at).getTime() > now
})
+ userGroupRates.value = rates
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t('admin.channelMonitor.form.noActiveKey')))
} finally {
diff --git a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
index 4fd71cb2..8df8d586 100644
--- a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
@@ -47,9 +47,14 @@
{{ k.name }}
{{ maskApiKey(k.key) }}
-
- {{ k.group.name }}
-
+
—
@@ -73,14 +78,18 @@ import { useI18n } from 'vue-i18n'
import type { ApiKey } from '@/types'
import type { Provider } from '@/api/admin/channelMonitor'
import BaseDialog from '@/components/common/BaseDialog.vue'
+import GroupBadge from '@/components/common/GroupBadge.vue'
import { maskApiKey } from '@/utils/maskApiKey'
-const props = defineProps<{
+const props = withDefaults(defineProps<{
show: boolean
loading: boolean
keys: ApiKey[]
provider: Provider
-}>()
+  userGroupRates?: Record<string, number>
+}>(), {
+ userGroupRates: () => ({}),
+})
defineEmits<{
(e: 'close'): void
--
GitLab
From ef6ec8a15a7d87f8c7f2cbf98c4c784b4b8b63b7 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 10:45:30 +0800
Subject: [PATCH 089/261] fix(channel-monitor): drop soft delete, refactor
feature flag to declarative form
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### 后端修复:日志表不该用软删除
channel_monitor_histories / channel_monitor_daily_rollups 都是日志/聚合表,
没有恢复需求。110 里加的 SoftDeleteMixin 会让 DELETE 自动变成 UPDATE deleted_at,
导致行和索引只增不减,徒增磁盘占用和查询成本。
改回分批物理删(参考 OpsCleanupService.deleteOldRowsByID 模板):
- ent schema 移除 SoftDeleteMixin,重新 go generate
- repo 新增 deleteChannelMonitorBatched 辅助 + 两条 prune SQL 常量
(WITH batch AS SELECT id LIMIT 5000 → DELETE IN batch)
- DeleteHistoryBefore / DeleteRollupsBefore 改调分批 raw SQL
- 移除 ComputeAvailability / ComputeAvailabilityForMonitors / UpsertDailyRollupsFor /
ListLatestPerModel / ListLatestForMonitorIDs / ListRecentHistoryForMonitors 等
raw SQL 中的 deleted_at IS NULL 过滤
- UpsertDailyRollupsFor 的 ON CONFLICT 去掉 deleted_at = NULL 重置
- migration 111 DROP COLUMN deleted_at + 对应索引(110 已部署但 maintenance
首跑在次日 02:00,此时尚无业务数据在依赖软删除)
### 前端重构:feature flag 声明式 + 复用
AppSidebar.vue 里 7 处 `...(flag ? [item] : [])` 样板代码删光,改为 NavItem 加
featureFlag?: () => boolean | undefined 字段,加一个 applyFeatureFlags 递归
过滤(含 children)。语义统一为 `!== false`(宽容策略,undefined 时默认显示,
避免 public settings 未加载完成时菜单闪烁消失 — 对应用户反馈"刷新后菜单消失
要去保存设置才回来")。
- 集中声明 4 个 flag getter:flagChannelMonitor / flagPayment /
flagOpsMonitoring / flagAdminPayment
- 提取 buildSelfNavItems 复用用户端主菜单和管理员"我的账户"子菜单
- 未来新增开关:在统一位置加一个 flag getter + 给对应 NavItem 加字段
(不用再动渲染逻辑)
bump 0.1.114.29
---
backend/ent/channelmonitordailyrollup.go | 16 +-
.../channelmonitordailyrollup.go | 16 --
.../ent/channelmonitordailyrollup/where.go | 55 ------
.../ent/channelmonitordailyrollup_create.go | 94 +---------
.../ent/channelmonitordailyrollup_query.go | 8 +-
.../ent/channelmonitordailyrollup_update.go | 72 +-------
backend/ent/channelmonitorhistory.go | 16 +-
.../channelmonitorhistory.go | 16 --
backend/ent/channelmonitorhistory/where.go | 55 ------
backend/ent/channelmonitorhistory_create.go | 94 +---------
backend/ent/channelmonitorhistory_query.go | 8 +-
backend/ent/channelmonitorhistory_update.go | 52 ------
backend/ent/client.go | 12 +-
backend/ent/migrate/schema.go | 14 +-
backend/ent/mutation.go | 155 +---------------
backend/ent/runtime/runtime.go | 10 -
.../schema/channel_monitor_daily_rollup.go | 9 +-
backend/ent/schema/channel_monitor_history.go | 13 +-
.../repository/channel_monitor_repo.go | 88 +++++----
.../127_drop_channel_monitor_deleted_at.sql | 16 ++
frontend/src/components/layout/AppSidebar.vue | 171 ++++++++----------
21 files changed, 188 insertions(+), 802 deletions(-)
create mode 100644 backend/migrations/127_drop_channel_monitor_deleted_at.sql
diff --git a/backend/ent/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup.go
index 6c7a8afa..78a5f489 100644
--- a/backend/ent/channelmonitordailyrollup.go
+++ b/backend/ent/channelmonitordailyrollup.go
@@ -18,8 +18,6 @@ type ChannelMonitorDailyRollup struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
- // DeletedAt holds the value of the "deleted_at" field.
- DeletedAt *time.Time `json:"deleted_at,omitempty"`
// MonitorID holds the value of the "monitor_id" field.
MonitorID int64 `json:"monitor_id,omitempty"`
// Model holds the value of the "model" field.
@@ -83,7 +81,7 @@ func (*ChannelMonitorDailyRollup) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case channelmonitordailyrollup.FieldModel:
values[i] = new(sql.NullString)
- case channelmonitordailyrollup.FieldDeletedAt, channelmonitordailyrollup.FieldBucketDate, channelmonitordailyrollup.FieldComputedAt:
+ case channelmonitordailyrollup.FieldBucketDate, channelmonitordailyrollup.FieldComputedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -106,13 +104,6 @@ func (_m *ChannelMonitorDailyRollup) assignValues(columns []string, values []any
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
- case channelmonitordailyrollup.FieldDeletedAt:
- if value, ok := values[i].(*sql.NullTime); !ok {
- return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
- } else if value.Valid {
- _m.DeletedAt = new(time.Time)
- *_m.DeletedAt = value.Time
- }
case channelmonitordailyrollup.FieldMonitorID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
@@ -238,11 +229,6 @@ func (_m *ChannelMonitorDailyRollup) String() string {
var builder strings.Builder
builder.WriteString("ChannelMonitorDailyRollup(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
- if v := _m.DeletedAt; v != nil {
- builder.WriteString("deleted_at=")
- builder.WriteString(v.Format(time.ANSIC))
- }
- builder.WriteString(", ")
builder.WriteString("monitor_id=")
builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
builder.WriteString(", ")
diff --git a/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
index eb1f69a8..e7cb9307 100644
--- a/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
+++ b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
@@ -5,7 +5,6 @@ package channelmonitordailyrollup
import (
"time"
- "entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
@@ -15,8 +14,6 @@ const (
Label = "channel_monitor_daily_rollup"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
- // FieldDeletedAt holds the string denoting the deleted_at field in the database.
- FieldDeletedAt = "deleted_at"
// FieldMonitorID holds the string denoting the monitor_id field in the database.
FieldMonitorID = "monitor_id"
// FieldModel holds the string denoting the model field in the database.
@@ -61,7 +58,6 @@ const (
// Columns holds all SQL columns for channelmonitordailyrollup fields.
var Columns = []string{
FieldID,
- FieldDeletedAt,
FieldMonitorID,
FieldModel,
FieldBucketDate,
@@ -88,14 +84,7 @@ func ValidColumn(column string) bool {
return false
}
-// Note that the variables below are initialized by the runtime
-// package on the initialization of the application. Therefore,
-// it should be imported in the main as follows:
-//
-// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
- Hooks [1]ent.Hook
- Interceptors [1]ent.Interceptor
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// DefaultTotalChecks holds the default value on creation for the "total_checks" field.
@@ -132,11 +121,6 @@ func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
-// ByDeletedAt orders the results by the deleted_at field.
-func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
-}
-
// ByMonitorID orders the results by the monitor_id field.
func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
diff --git a/backend/ent/channelmonitordailyrollup/where.go b/backend/ent/channelmonitordailyrollup/where.go
index 9da8d4be..424c957e 100644
--- a/backend/ent/channelmonitordailyrollup/where.go
+++ b/backend/ent/channelmonitordailyrollup/where.go
@@ -55,11 +55,6 @@ func IDLTE(id int64) predicate.ChannelMonitorDailyRollup {
return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldID, id))
}
-// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
-func DeletedAt(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDeletedAt, v))
-}
-
// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
func MonitorID(v int64) predicate.ChannelMonitorDailyRollup {
return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
@@ -130,56 +125,6 @@ func ComputedAt(v time.Time) predicate.ChannelMonitorDailyRollup {
return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
}
-// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
-func DeletedAtEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDeletedAt, v))
-}
-
-// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
-func DeletedAtNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldDeletedAt, v))
-}
-
-// DeletedAtIn applies the In predicate on the "deleted_at" field.
-func DeletedAtIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldDeletedAt, vs...))
-}
-
-// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
-func DeletedAtNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldDeletedAt, vs...))
-}
-
-// DeletedAtGT applies the GT predicate on the "deleted_at" field.
-func DeletedAtGT(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldDeletedAt, v))
-}
-
-// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
-func DeletedAtGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldDeletedAt, v))
-}
-
-// DeletedAtLT applies the LT predicate on the "deleted_at" field.
-func DeletedAtLT(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldDeletedAt, v))
-}
-
-// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
-func DeletedAtLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldDeletedAt, v))
-}
-
-// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
-func DeletedAtIsNil() predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldIsNull(FieldDeletedAt))
-}
-
-// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
-func DeletedAtNotNil() predicate.ChannelMonitorDailyRollup {
- return predicate.ChannelMonitorDailyRollup(sql.FieldNotNull(FieldDeletedAt))
-}
-
// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
func MonitorIDEQ(v int64) predicate.ChannelMonitorDailyRollup {
return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
diff --git a/backend/ent/channelmonitordailyrollup_create.go b/backend/ent/channelmonitordailyrollup_create.go
index c4850751..5f8754ba 100644
--- a/backend/ent/channelmonitordailyrollup_create.go
+++ b/backend/ent/channelmonitordailyrollup_create.go
@@ -23,20 +23,6 @@ type ChannelMonitorDailyRollupCreate struct {
conflict []sql.ConflictOption
}
-// SetDeletedAt sets the "deleted_at" field.
-func (_c *ChannelMonitorDailyRollupCreate) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupCreate {
- _c.mutation.SetDeletedAt(v)
- return _c
-}
-
-// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
-func (_c *ChannelMonitorDailyRollupCreate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorDailyRollupCreate {
- if v != nil {
- _c.SetDeletedAt(*v)
- }
- return _c
-}
-
// SetMonitorID sets the "monitor_id" field.
func (_c *ChannelMonitorDailyRollupCreate) SetMonitorID(v int64) *ChannelMonitorDailyRollupCreate {
_c.mutation.SetMonitorID(v)
@@ -221,9 +207,7 @@ func (_c *ChannelMonitorDailyRollupCreate) Mutation() *ChannelMonitorDailyRollup
// Save creates the ChannelMonitorDailyRollup in the database.
func (_c *ChannelMonitorDailyRollupCreate) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
- if err := _c.defaults(); err != nil {
- return nil, err
- }
+ _c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
@@ -250,7 +234,7 @@ func (_c *ChannelMonitorDailyRollupCreate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
-func (_c *ChannelMonitorDailyRollupCreate) defaults() error {
+func (_c *ChannelMonitorDailyRollupCreate) defaults() {
if _, ok := _c.mutation.TotalChecks(); !ok {
v := channelmonitordailyrollup.DefaultTotalChecks
_c.mutation.SetTotalChecks(v)
@@ -292,13 +276,9 @@ func (_c *ChannelMonitorDailyRollupCreate) defaults() error {
_c.mutation.SetCountPingLatency(v)
}
if _, ok := _c.mutation.ComputedAt(); !ok {
- if channelmonitordailyrollup.DefaultComputedAt == nil {
- return fmt.Errorf("ent: uninitialized channelmonitordailyrollup.DefaultComputedAt (forgotten import ent/runtime?)")
- }
v := channelmonitordailyrollup.DefaultComputedAt()
_c.mutation.SetComputedAt(v)
}
- return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -380,10 +360,6 @@ func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRol
_spec = sqlgraph.NewCreateSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
- if value, ok := _c.mutation.DeletedAt(); ok {
- _spec.SetField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime, value)
- _node.DeletedAt = &value
- }
if value, ok := _c.mutation.Model(); ok {
_spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
_node.Model = value
@@ -460,7 +436,7 @@ func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRol
// of the `INSERT` statement. For example:
//
// client.ChannelMonitorDailyRollup.Create().
-// SetDeletedAt(v).
+// SetMonitorID(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
@@ -469,7 +445,7 @@ func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRol
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
-// SetDeletedAt(v+v).
+// SetMonitorID(v+v).
// }).
// Exec(ctx)
func (_c *ChannelMonitorDailyRollupCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertOne {
@@ -505,24 +481,6 @@ type (
}
)
-// SetDeletedAt sets the "deleted_at" field.
-func (u *ChannelMonitorDailyRollupUpsert) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpsert {
- u.Set(channelmonitordailyrollup.FieldDeletedAt, v)
- return u
-}
-
-// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
-func (u *ChannelMonitorDailyRollupUpsert) UpdateDeletedAt() *ChannelMonitorDailyRollupUpsert {
- u.SetExcluded(channelmonitordailyrollup.FieldDeletedAt)
- return u
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (u *ChannelMonitorDailyRollupUpsert) ClearDeletedAt() *ChannelMonitorDailyRollupUpsert {
- u.SetNull(channelmonitordailyrollup.FieldDeletedAt)
- return u
-}
-
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorDailyRollupUpsert) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsert {
u.Set(channelmonitordailyrollup.FieldMonitorID, v)
@@ -791,27 +749,6 @@ func (u *ChannelMonitorDailyRollupUpsertOne) Update(set func(*ChannelMonitorDail
return u
}
-// SetDeletedAt sets the "deleted_at" field.
-func (u *ChannelMonitorDailyRollupUpsertOne) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
- return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
- s.SetDeletedAt(v)
- })
-}
-
-// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
-func (u *ChannelMonitorDailyRollupUpsertOne) UpdateDeletedAt() *ChannelMonitorDailyRollupUpsertOne {
- return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
- s.UpdateDeletedAt()
- })
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (u *ChannelMonitorDailyRollupUpsertOne) ClearDeletedAt() *ChannelMonitorDailyRollupUpsertOne {
- return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
- s.ClearDeletedAt()
- })
-}
-
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertOne {
return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
@@ -1213,7 +1150,7 @@ func (_c *ChannelMonitorDailyRollupCreateBulk) ExecX(ctx context.Context) {
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
-// SetDeletedAt(v+v).
+// SetMonitorID(v+v).
// }).
// Exec(ctx)
func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertBulk {
@@ -1282,27 +1219,6 @@ func (u *ChannelMonitorDailyRollupUpsertBulk) Update(set func(*ChannelMonitorDai
return u
}
-// SetDeletedAt sets the "deleted_at" field.
-func (u *ChannelMonitorDailyRollupUpsertBulk) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
- return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
- s.SetDeletedAt(v)
- })
-}
-
-// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
-func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateDeletedAt() *ChannelMonitorDailyRollupUpsertBulk {
- return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
- s.UpdateDeletedAt()
- })
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (u *ChannelMonitorDailyRollupUpsertBulk) ClearDeletedAt() *ChannelMonitorDailyRollupUpsertBulk {
- return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
- s.ClearDeletedAt()
- })
-}
-
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertBulk {
return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
diff --git a/backend/ent/channelmonitordailyrollup_query.go b/backend/ent/channelmonitordailyrollup_query.go
index 30528575..e34afc61 100644
--- a/backend/ent/channelmonitordailyrollup_query.go
+++ b/backend/ent/channelmonitordailyrollup_query.go
@@ -300,12 +300,12 @@ func (_q *ChannelMonitorDailyRollupQuery) WithMonitor(opts ...func(*ChannelMonit
// Example:
//
// var v []struct {
-// DeletedAt time.Time `json:"deleted_at,omitempty"`
+// MonitorID int64 `json:"monitor_id,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.ChannelMonitorDailyRollup.Query().
-// GroupBy(channelmonitordailyrollup.FieldDeletedAt).
+// GroupBy(channelmonitordailyrollup.FieldMonitorID).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *ChannelMonitorDailyRollupQuery) GroupBy(field string, fields ...string) *ChannelMonitorDailyRollupGroupBy {
@@ -323,11 +323,11 @@ func (_q *ChannelMonitorDailyRollupQuery) GroupBy(field string, fields ...string
// Example:
//
// var v []struct {
-// DeletedAt time.Time `json:"deleted_at,omitempty"`
+// MonitorID int64 `json:"monitor_id,omitempty"`
// }
//
// client.ChannelMonitorDailyRollup.Query().
-// Select(channelmonitordailyrollup.FieldDeletedAt).
+// Select(channelmonitordailyrollup.FieldMonitorID).
// Scan(ctx, &v)
func (_q *ChannelMonitorDailyRollupQuery) Select(fields ...string) *ChannelMonitorDailyRollupSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
diff --git a/backend/ent/channelmonitordailyrollup_update.go b/backend/ent/channelmonitordailyrollup_update.go
index 0b82f8bf..02cd86c5 100644
--- a/backend/ent/channelmonitordailyrollup_update.go
+++ b/backend/ent/channelmonitordailyrollup_update.go
@@ -29,26 +29,6 @@ func (_u *ChannelMonitorDailyRollupUpdate) Where(ps ...predicate.ChannelMonitorD
return _u
}
-// SetDeletedAt sets the "deleted_at" field.
-func (_u *ChannelMonitorDailyRollupUpdate) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpdate {
- _u.mutation.SetDeletedAt(v)
- return _u
-}
-
-// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
-func (_u *ChannelMonitorDailyRollupUpdate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorDailyRollupUpdate {
- if v != nil {
- _u.SetDeletedAt(*v)
- }
- return _u
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (_u *ChannelMonitorDailyRollupUpdate) ClearDeletedAt() *ChannelMonitorDailyRollupUpdate {
- _u.mutation.ClearDeletedAt()
- return _u
-}
-
// SetMonitorID sets the "monitor_id" field.
func (_u *ChannelMonitorDailyRollupUpdate) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdate {
_u.mutation.SetMonitorID(v)
@@ -325,9 +305,7 @@ func (_u *ChannelMonitorDailyRollupUpdate) ClearMonitor() *ChannelMonitorDailyRo
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *ChannelMonitorDailyRollupUpdate) Save(ctx context.Context) (int, error) {
- if err := _u.defaults(); err != nil {
- return 0, err
- }
+ _u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
@@ -354,15 +332,11 @@ func (_u *ChannelMonitorDailyRollupUpdate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
-func (_u *ChannelMonitorDailyRollupUpdate) defaults() error {
+func (_u *ChannelMonitorDailyRollupUpdate) defaults() {
if _, ok := _u.mutation.ComputedAt(); !ok {
- if channelmonitordailyrollup.UpdateDefaultComputedAt == nil {
- return fmt.Errorf("ent: uninitialized channelmonitordailyrollup.UpdateDefaultComputedAt (forgotten import ent/runtime?)")
- }
v := channelmonitordailyrollup.UpdateDefaultComputedAt()
_u.mutation.SetComputedAt(v)
}
- return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -390,12 +364,6 @@ func (_u *ChannelMonitorDailyRollupUpdate) sqlSave(ctx context.Context) (_node i
}
}
}
- if value, ok := _u.mutation.DeletedAt(); ok {
- _spec.SetField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime, value)
- }
- if _u.mutation.DeletedAtCleared() {
- _spec.ClearField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime)
- }
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
}
@@ -514,26 +482,6 @@ type ChannelMonitorDailyRollupUpdateOne struct {
mutation *ChannelMonitorDailyRollupMutation
}
-// SetDeletedAt sets the "deleted_at" field.
-func (_u *ChannelMonitorDailyRollupUpdateOne) SetDeletedAt(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
- _u.mutation.SetDeletedAt(v)
- return _u
-}
-
-// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
-func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableDeletedAt(v *time.Time) *ChannelMonitorDailyRollupUpdateOne {
- if v != nil {
- _u.SetDeletedAt(*v)
- }
- return _u
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (_u *ChannelMonitorDailyRollupUpdateOne) ClearDeletedAt() *ChannelMonitorDailyRollupUpdateOne {
- _u.mutation.ClearDeletedAt()
- return _u
-}
-
// SetMonitorID sets the "monitor_id" field.
func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdateOne {
_u.mutation.SetMonitorID(v)
@@ -823,9 +771,7 @@ func (_u *ChannelMonitorDailyRollupUpdateOne) Select(field string, fields ...str
// Save executes the query and returns the updated ChannelMonitorDailyRollup entity.
func (_u *ChannelMonitorDailyRollupUpdateOne) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
- if err := _u.defaults(); err != nil {
- return nil, err
- }
+ _u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
@@ -852,15 +798,11 @@ func (_u *ChannelMonitorDailyRollupUpdateOne) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
-func (_u *ChannelMonitorDailyRollupUpdateOne) defaults() error {
+func (_u *ChannelMonitorDailyRollupUpdateOne) defaults() {
if _, ok := _u.mutation.ComputedAt(); !ok {
- if channelmonitordailyrollup.UpdateDefaultComputedAt == nil {
- return fmt.Errorf("ent: uninitialized channelmonitordailyrollup.UpdateDefaultComputedAt (forgotten import ent/runtime?)")
- }
v := channelmonitordailyrollup.UpdateDefaultComputedAt()
_u.mutation.SetComputedAt(v)
}
- return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -905,12 +847,6 @@ func (_u *ChannelMonitorDailyRollupUpdateOne) sqlSave(ctx context.Context) (_nod
}
}
}
- if value, ok := _u.mutation.DeletedAt(); ok {
- _spec.SetField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime, value)
- }
- if _u.mutation.DeletedAtCleared() {
- _spec.ClearField(channelmonitordailyrollup.FieldDeletedAt, field.TypeTime)
- }
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
}
diff --git a/backend/ent/channelmonitorhistory.go b/backend/ent/channelmonitorhistory.go
index 256eaf5f..70dde542 100644
--- a/backend/ent/channelmonitorhistory.go
+++ b/backend/ent/channelmonitorhistory.go
@@ -18,8 +18,6 @@ type ChannelMonitorHistory struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
- // DeletedAt holds the value of the "deleted_at" field.
- DeletedAt *time.Time `json:"deleted_at,omitempty"`
// MonitorID holds the value of the "monitor_id" field.
MonitorID int64 `json:"monitor_id,omitempty"`
// Model holds the value of the "model" field.
@@ -69,7 +67,7 @@ func (*ChannelMonitorHistory) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case channelmonitorhistory.FieldModel, channelmonitorhistory.FieldStatus, channelmonitorhistory.FieldMessage:
values[i] = new(sql.NullString)
- case channelmonitorhistory.FieldDeletedAt, channelmonitorhistory.FieldCheckedAt:
+ case channelmonitorhistory.FieldCheckedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -92,13 +90,6 @@ func (_m *ChannelMonitorHistory) assignValues(columns []string, values []any) er
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
- case channelmonitorhistory.FieldDeletedAt:
- if value, ok := values[i].(*sql.NullTime); !ok {
- return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
- } else if value.Valid {
- _m.DeletedAt = new(time.Time)
- *_m.DeletedAt = value.Time
- }
case channelmonitorhistory.FieldMonitorID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
@@ -184,11 +175,6 @@ func (_m *ChannelMonitorHistory) String() string {
var builder strings.Builder
builder.WriteString("ChannelMonitorHistory(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
- if v := _m.DeletedAt; v != nil {
- builder.WriteString("deleted_at=")
- builder.WriteString(v.Format(time.ANSIC))
- }
- builder.WriteString(", ")
builder.WriteString("monitor_id=")
builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
builder.WriteString(", ")
diff --git a/backend/ent/channelmonitorhistory/channelmonitorhistory.go b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
index da59791b..6a9dc006 100644
--- a/backend/ent/channelmonitorhistory/channelmonitorhistory.go
+++ b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
@@ -6,7 +6,6 @@ import (
"fmt"
"time"
- "entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
@@ -16,8 +15,6 @@ const (
Label = "channel_monitor_history"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
- // FieldDeletedAt holds the string denoting the deleted_at field in the database.
- FieldDeletedAt = "deleted_at"
// FieldMonitorID holds the string denoting the monitor_id field in the database.
FieldMonitorID = "monitor_id"
// FieldModel holds the string denoting the model field in the database.
@@ -48,7 +45,6 @@ const (
// Columns holds all SQL columns for channelmonitorhistory fields.
var Columns = []string{
FieldID,
- FieldDeletedAt,
FieldMonitorID,
FieldModel,
FieldStatus,
@@ -68,14 +64,7 @@ func ValidColumn(column string) bool {
return false
}
-// Note that the variables below are initialized by the runtime
-// package on the initialization of the application. Therefore,
-// it should be imported in the main as follows:
-//
-// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
- Hooks [1]ent.Hook
- Interceptors [1]ent.Interceptor
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// DefaultMessage holds the default value on creation for the "message" field.
@@ -119,11 +108,6 @@ func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
-// ByDeletedAt orders the results by the deleted_at field.
-func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
-}
-
// ByMonitorID orders the results by the monitor_id field.
func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
diff --git a/backend/ent/channelmonitorhistory/where.go b/backend/ent/channelmonitorhistory/where.go
index 7b1cd50d..afa73f35 100644
--- a/backend/ent/channelmonitorhistory/where.go
+++ b/backend/ent/channelmonitorhistory/where.go
@@ -55,11 +55,6 @@ func IDLTE(id int64) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldID, id))
}
-// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
-func DeletedAt(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldDeletedAt, v))
-}
-
// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
func MonitorID(v int64) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
@@ -90,56 +85,6 @@ func CheckedAt(v time.Time) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v))
}
-// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
-func DeletedAtEQ(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldDeletedAt, v))
-}
-
-// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
-func DeletedAtNEQ(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldDeletedAt, v))
-}
-
-// DeletedAtIn applies the In predicate on the "deleted_at" field.
-func DeletedAtIn(vs ...time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldIn(FieldDeletedAt, vs...))
-}
-
-// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
-func DeletedAtNotIn(vs ...time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldDeletedAt, vs...))
-}
-
-// DeletedAtGT applies the GT predicate on the "deleted_at" field.
-func DeletedAtGT(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldGT(FieldDeletedAt, v))
-}
-
-// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
-func DeletedAtGTE(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldDeletedAt, v))
-}
-
-// DeletedAtLT applies the LT predicate on the "deleted_at" field.
-func DeletedAtLT(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldLT(FieldDeletedAt, v))
-}
-
-// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
-func DeletedAtLTE(v time.Time) predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldDeletedAt, v))
-}
-
-// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
-func DeletedAtIsNil() predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldDeletedAt))
-}
-
-// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
-func DeletedAtNotNil() predicate.ChannelMonitorHistory {
- return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldDeletedAt))
-}
-
// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
func MonitorIDEQ(v int64) predicate.ChannelMonitorHistory {
return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
diff --git a/backend/ent/channelmonitorhistory_create.go b/backend/ent/channelmonitorhistory_create.go
index 9a68c9ce..71034865 100644
--- a/backend/ent/channelmonitorhistory_create.go
+++ b/backend/ent/channelmonitorhistory_create.go
@@ -23,20 +23,6 @@ type ChannelMonitorHistoryCreate struct {
conflict []sql.ConflictOption
}
-// SetDeletedAt sets the "deleted_at" field.
-func (_c *ChannelMonitorHistoryCreate) SetDeletedAt(v time.Time) *ChannelMonitorHistoryCreate {
- _c.mutation.SetDeletedAt(v)
- return _c
-}
-
-// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
-func (_c *ChannelMonitorHistoryCreate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorHistoryCreate {
- if v != nil {
- _c.SetDeletedAt(*v)
- }
- return _c
-}
-
// SetMonitorID sets the "monitor_id" field.
func (_c *ChannelMonitorHistoryCreate) SetMonitorID(v int64) *ChannelMonitorHistoryCreate {
_c.mutation.SetMonitorID(v)
@@ -123,9 +109,7 @@ func (_c *ChannelMonitorHistoryCreate) Mutation() *ChannelMonitorHistoryMutation
// Save creates the ChannelMonitorHistory in the database.
func (_c *ChannelMonitorHistoryCreate) Save(ctx context.Context) (*ChannelMonitorHistory, error) {
- if err := _c.defaults(); err != nil {
- return nil, err
- }
+ _c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
@@ -152,19 +136,15 @@ func (_c *ChannelMonitorHistoryCreate) ExecX(ctx context.Context) {
}
// defaults sets the default values of the builder before save.
-func (_c *ChannelMonitorHistoryCreate) defaults() error {
+func (_c *ChannelMonitorHistoryCreate) defaults() {
if _, ok := _c.mutation.Message(); !ok {
v := channelmonitorhistory.DefaultMessage
_c.mutation.SetMessage(v)
}
if _, ok := _c.mutation.CheckedAt(); !ok {
- if channelmonitorhistory.DefaultCheckedAt == nil {
- return fmt.Errorf("ent: uninitialized channelmonitorhistory.DefaultCheckedAt (forgotten import ent/runtime?)")
- }
v := channelmonitorhistory.DefaultCheckedAt()
_c.mutation.SetCheckedAt(v)
}
- return nil
}
// check runs all checks and user-defined validators on the builder.
@@ -226,10 +206,6 @@ func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sq
_spec = sqlgraph.NewCreateSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
- if value, ok := _c.mutation.DeletedAt(); ok {
- _spec.SetField(channelmonitorhistory.FieldDeletedAt, field.TypeTime, value)
- _node.DeletedAt = &value
- }
if value, ok := _c.mutation.Model(); ok {
_spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
_node.Model = value
@@ -278,7 +254,7 @@ func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sq
// of the `INSERT` statement. For example:
//
// client.ChannelMonitorHistory.Create().
-// SetDeletedAt(v).
+// SetMonitorID(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
@@ -287,7 +263,7 @@ func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sq
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
-// SetDeletedAt(v+v).
+// SetMonitorID(v+v).
// }).
// Exec(ctx)
func (_c *ChannelMonitorHistoryCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertOne {
@@ -323,24 +299,6 @@ type (
}
)
-// SetDeletedAt sets the "deleted_at" field.
-func (u *ChannelMonitorHistoryUpsert) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpsert {
- u.Set(channelmonitorhistory.FieldDeletedAt, v)
- return u
-}
-
-// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
-func (u *ChannelMonitorHistoryUpsert) UpdateDeletedAt() *ChannelMonitorHistoryUpsert {
- u.SetExcluded(channelmonitorhistory.FieldDeletedAt)
- return u
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (u *ChannelMonitorHistoryUpsert) ClearDeletedAt() *ChannelMonitorHistoryUpsert {
- u.SetNull(channelmonitorhistory.FieldDeletedAt)
- return u
-}
-
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorHistoryUpsert) SetMonitorID(v int64) *ChannelMonitorHistoryUpsert {
u.Set(channelmonitorhistory.FieldMonitorID, v)
@@ -495,27 +453,6 @@ func (u *ChannelMonitorHistoryUpsertOne) Update(set func(*ChannelMonitorHistoryU
return u
}
-// SetDeletedAt sets the "deleted_at" field.
-func (u *ChannelMonitorHistoryUpsertOne) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpsertOne {
- return u.Update(func(s *ChannelMonitorHistoryUpsert) {
- s.SetDeletedAt(v)
- })
-}
-
-// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
-func (u *ChannelMonitorHistoryUpsertOne) UpdateDeletedAt() *ChannelMonitorHistoryUpsertOne {
- return u.Update(func(s *ChannelMonitorHistoryUpsert) {
- s.UpdateDeletedAt()
- })
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (u *ChannelMonitorHistoryUpsertOne) ClearDeletedAt() *ChannelMonitorHistoryUpsertOne {
- return u.Update(func(s *ChannelMonitorHistoryUpsert) {
- s.ClearDeletedAt()
- })
-}
-
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorHistoryUpsertOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertOne {
return u.Update(func(s *ChannelMonitorHistoryUpsert) {
@@ -784,7 +721,7 @@ func (_c *ChannelMonitorHistoryCreateBulk) ExecX(ctx context.Context) {
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
-// SetDeletedAt(v+v).
+// SetMonitorID(v+v).
// }).
// Exec(ctx)
func (_c *ChannelMonitorHistoryCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertBulk {
@@ -853,27 +790,6 @@ func (u *ChannelMonitorHistoryUpsertBulk) Update(set func(*ChannelMonitorHistory
return u
}
-// SetDeletedAt sets the "deleted_at" field.
-func (u *ChannelMonitorHistoryUpsertBulk) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpsertBulk {
- return u.Update(func(s *ChannelMonitorHistoryUpsert) {
- s.SetDeletedAt(v)
- })
-}
-
-// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create.
-func (u *ChannelMonitorHistoryUpsertBulk) UpdateDeletedAt() *ChannelMonitorHistoryUpsertBulk {
- return u.Update(func(s *ChannelMonitorHistoryUpsert) {
- s.UpdateDeletedAt()
- })
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (u *ChannelMonitorHistoryUpsertBulk) ClearDeletedAt() *ChannelMonitorHistoryUpsertBulk {
- return u.Update(func(s *ChannelMonitorHistoryUpsert) {
- s.ClearDeletedAt()
- })
-}
-
// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorHistoryUpsertBulk) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertBulk {
return u.Update(func(s *ChannelMonitorHistoryUpsert) {
diff --git a/backend/ent/channelmonitorhistory_query.go b/backend/ent/channelmonitorhistory_query.go
index 26a1528f..1fb872ad 100644
--- a/backend/ent/channelmonitorhistory_query.go
+++ b/backend/ent/channelmonitorhistory_query.go
@@ -300,12 +300,12 @@ func (_q *ChannelMonitorHistoryQuery) WithMonitor(opts ...func(*ChannelMonitorQu
// Example:
//
// var v []struct {
-// DeletedAt time.Time `json:"deleted_at,omitempty"`
+// MonitorID int64 `json:"monitor_id,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.ChannelMonitorHistory.Query().
-// GroupBy(channelmonitorhistory.FieldDeletedAt).
+// GroupBy(channelmonitorhistory.FieldMonitorID).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *ChannelMonitorHistoryGroupBy {
@@ -323,11 +323,11 @@ func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *C
// Example:
//
// var v []struct {
-// DeletedAt time.Time `json:"deleted_at,omitempty"`
+// MonitorID int64 `json:"monitor_id,omitempty"`
// }
//
// client.ChannelMonitorHistory.Query().
-// Select(channelmonitorhistory.FieldDeletedAt).
+// Select(channelmonitorhistory.FieldMonitorID).
// Scan(ctx, &v)
func (_q *ChannelMonitorHistoryQuery) Select(fields ...string) *ChannelMonitorHistorySelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
diff --git a/backend/ent/channelmonitorhistory_update.go b/backend/ent/channelmonitorhistory_update.go
index 85193ec1..a85a8072 100644
--- a/backend/ent/channelmonitorhistory_update.go
+++ b/backend/ent/channelmonitorhistory_update.go
@@ -29,26 +29,6 @@ func (_u *ChannelMonitorHistoryUpdate) Where(ps ...predicate.ChannelMonitorHisto
return _u
}
-// SetDeletedAt sets the "deleted_at" field.
-func (_u *ChannelMonitorHistoryUpdate) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpdate {
- _u.mutation.SetDeletedAt(v)
- return _u
-}
-
-// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
-func (_u *ChannelMonitorHistoryUpdate) SetNillableDeletedAt(v *time.Time) *ChannelMonitorHistoryUpdate {
- if v != nil {
- _u.SetDeletedAt(*v)
- }
- return _u
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (_u *ChannelMonitorHistoryUpdate) ClearDeletedAt() *ChannelMonitorHistoryUpdate {
- _u.mutation.ClearDeletedAt()
- return _u
-}
-
// SetMonitorID sets the "monitor_id" field.
func (_u *ChannelMonitorHistoryUpdate) SetMonitorID(v int64) *ChannelMonitorHistoryUpdate {
_u.mutation.SetMonitorID(v)
@@ -257,12 +237,6 @@ func (_u *ChannelMonitorHistoryUpdate) sqlSave(ctx context.Context) (_node int,
}
}
}
- if value, ok := _u.mutation.DeletedAt(); ok {
- _spec.SetField(channelmonitorhistory.FieldDeletedAt, field.TypeTime, value)
- }
- if _u.mutation.DeletedAtCleared() {
- _spec.ClearField(channelmonitorhistory.FieldDeletedAt, field.TypeTime)
- }
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
}
@@ -345,26 +319,6 @@ type ChannelMonitorHistoryUpdateOne struct {
mutation *ChannelMonitorHistoryMutation
}
-// SetDeletedAt sets the "deleted_at" field.
-func (_u *ChannelMonitorHistoryUpdateOne) SetDeletedAt(v time.Time) *ChannelMonitorHistoryUpdateOne {
- _u.mutation.SetDeletedAt(v)
- return _u
-}
-
-// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
-func (_u *ChannelMonitorHistoryUpdateOne) SetNillableDeletedAt(v *time.Time) *ChannelMonitorHistoryUpdateOne {
- if v != nil {
- _u.SetDeletedAt(*v)
- }
- return _u
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (_u *ChannelMonitorHistoryUpdateOne) ClearDeletedAt() *ChannelMonitorHistoryUpdateOne {
- _u.mutation.ClearDeletedAt()
- return _u
-}
-
// SetMonitorID sets the "monitor_id" field.
func (_u *ChannelMonitorHistoryUpdateOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpdateOne {
_u.mutation.SetMonitorID(v)
@@ -603,12 +557,6 @@ func (_u *ChannelMonitorHistoryUpdateOne) sqlSave(ctx context.Context) (_node *C
}
}
}
- if value, ok := _u.mutation.DeletedAt(); ok {
- _spec.SetField(channelmonitorhistory.FieldDeletedAt, field.TypeTime, value)
- }
- if _u.mutation.DeletedAtCleared() {
- _spec.ClearField(channelmonitorhistory.FieldDeletedAt, field.TypeTime)
- }
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index ca208094..ebc7fc5e 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -1912,14 +1912,12 @@ func (c *ChannelMonitorDailyRollupClient) QueryMonitor(_m *ChannelMonitorDailyRo
// Hooks returns the client hooks.
func (c *ChannelMonitorDailyRollupClient) Hooks() []Hook {
- hooks := c.hooks.ChannelMonitorDailyRollup
- return append(hooks[:len(hooks):len(hooks)], channelmonitordailyrollup.Hooks[:]...)
+ return c.hooks.ChannelMonitorDailyRollup
}
// Interceptors returns the client interceptors.
func (c *ChannelMonitorDailyRollupClient) Interceptors() []Interceptor {
- inters := c.inters.ChannelMonitorDailyRollup
- return append(inters[:len(inters):len(inters)], channelmonitordailyrollup.Interceptors[:]...)
+ return c.inters.ChannelMonitorDailyRollup
}
func (c *ChannelMonitorDailyRollupClient) mutate(ctx context.Context, m *ChannelMonitorDailyRollupMutation) (Value, error) {
@@ -2063,14 +2061,12 @@ func (c *ChannelMonitorHistoryClient) QueryMonitor(_m *ChannelMonitorHistory) *C
// Hooks returns the client hooks.
func (c *ChannelMonitorHistoryClient) Hooks() []Hook {
- hooks := c.hooks.ChannelMonitorHistory
- return append(hooks[:len(hooks):len(hooks)], channelmonitorhistory.Hooks[:]...)
+ return c.hooks.ChannelMonitorHistory
}
// Interceptors returns the client interceptors.
func (c *ChannelMonitorHistoryClient) Interceptors() []Interceptor {
- inters := c.inters.ChannelMonitorHistory
- return append(inters[:len(inters):len(inters)], channelmonitorhistory.Interceptors[:]...)
+ return c.inters.ChannelMonitorHistory
}
func (c *ChannelMonitorHistoryClient) mutate(ctx context.Context, m *ChannelMonitorHistoryMutation) (Value, error) {
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 9ce914a3..dba43ddf 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -464,7 +464,6 @@ var (
// ChannelMonitorDailyRollupsColumns holds the columns for the "channel_monitor_daily_rollups" table.
ChannelMonitorDailyRollupsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
- {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "model", Type: field.TypeString, Size: 200},
{Name: "bucket_date", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "date"}},
{Name: "total_checks", Type: field.TypeInt, Default: 0},
@@ -488,7 +487,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "channel_monitor_daily_rollups_channel_monitors_daily_rollups",
- Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[15]},
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14]},
RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
OnDelete: schema.Cascade,
},
@@ -497,19 +496,18 @@ var (
{
Name: "channelmonitordailyrollup_monitor_id_model_bucket_date",
Unique: true,
- Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[15], ChannelMonitorDailyRollupsColumns[2], ChannelMonitorDailyRollupsColumns[3]},
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14], ChannelMonitorDailyRollupsColumns[1], ChannelMonitorDailyRollupsColumns[2]},
},
{
Name: "channelmonitordailyrollup_bucket_date",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[3]},
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[2]},
},
},
}
// ChannelMonitorHistoriesColumns holds the columns for the "channel_monitor_histories" table.
ChannelMonitorHistoriesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
- {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "model", Type: field.TypeString, Size: 200},
{Name: "status", Type: field.TypeEnum, Enums: []string{"operational", "degraded", "failed", "error"}},
{Name: "latency_ms", Type: field.TypeInt, Nullable: true},
@@ -526,7 +524,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "channel_monitor_histories_channel_monitors_history",
- Columns: []*schema.Column{ChannelMonitorHistoriesColumns[8]},
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]},
RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
OnDelete: schema.Cascade,
},
@@ -535,12 +533,12 @@ var (
{
Name: "channelmonitorhistory_monitor_id_model_checked_at",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorHistoriesColumns[8], ChannelMonitorHistoriesColumns[2], ChannelMonitorHistoriesColumns[7]},
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7], ChannelMonitorHistoriesColumns[1], ChannelMonitorHistoriesColumns[6]},
},
{
Name: "channelmonitorhistory_checked_at",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]},
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[6]},
},
},
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index e97456fe..43e52371 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -10022,7 +10022,6 @@ type ChannelMonitorDailyRollupMutation struct {
op Op
typ string
id *int64
- deleted_at *time.Time
model *string
bucket_date *time.Time
total_checks *int
@@ -10152,55 +10151,6 @@ func (m *ChannelMonitorDailyRollupMutation) IDs(ctx context.Context) ([]int64, e
}
}
-// SetDeletedAt sets the "deleted_at" field.
-func (m *ChannelMonitorDailyRollupMutation) SetDeletedAt(t time.Time) {
- m.deleted_at = &t
-}
-
-// DeletedAt returns the value of the "deleted_at" field in the mutation.
-func (m *ChannelMonitorDailyRollupMutation) DeletedAt() (r time.Time, exists bool) {
- v := m.deleted_at
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldDeletedAt returns the old "deleted_at" field's value of the ChannelMonitorDailyRollup entity.
-// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorDailyRollupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDeletedAt requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err)
- }
- return oldValue.DeletedAt, nil
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (m *ChannelMonitorDailyRollupMutation) ClearDeletedAt() {
- m.deleted_at = nil
- m.clearedFields[channelmonitordailyrollup.FieldDeletedAt] = struct{}{}
-}
-
-// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation.
-func (m *ChannelMonitorDailyRollupMutation) DeletedAtCleared() bool {
- _, ok := m.clearedFields[channelmonitordailyrollup.FieldDeletedAt]
- return ok
-}
-
-// ResetDeletedAt resets all changes to the "deleted_at" field.
-func (m *ChannelMonitorDailyRollupMutation) ResetDeletedAt() {
- m.deleted_at = nil
- delete(m.clearedFields, channelmonitordailyrollup.FieldDeletedAt)
-}
-
// SetMonitorID sets the "monitor_id" field.
func (m *ChannelMonitorDailyRollupMutation) SetMonitorID(i int64) {
m.monitor = &i
@@ -10966,10 +10916,7 @@ func (m *ChannelMonitorDailyRollupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorDailyRollupMutation) Fields() []string {
- fields := make([]string, 0, 15)
- if m.deleted_at != nil {
- fields = append(fields, channelmonitordailyrollup.FieldDeletedAt)
- }
+ fields := make([]string, 0, 14)
if m.monitor != nil {
fields = append(fields, channelmonitordailyrollup.FieldMonitorID)
}
@@ -11020,8 +10967,6 @@ func (m *ChannelMonitorDailyRollupMutation) Fields() []string {
// schema.
func (m *ChannelMonitorDailyRollupMutation) Field(name string) (ent.Value, bool) {
switch name {
- case channelmonitordailyrollup.FieldDeletedAt:
- return m.DeletedAt()
case channelmonitordailyrollup.FieldMonitorID:
return m.MonitorID()
case channelmonitordailyrollup.FieldModel:
@@ -11059,8 +11004,6 @@ func (m *ChannelMonitorDailyRollupMutation) Field(name string) (ent.Value, bool)
// database failed.
func (m *ChannelMonitorDailyRollupMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
- case channelmonitordailyrollup.FieldDeletedAt:
- return m.OldDeletedAt(ctx)
case channelmonitordailyrollup.FieldMonitorID:
return m.OldMonitorID(ctx)
case channelmonitordailyrollup.FieldModel:
@@ -11098,13 +11041,6 @@ func (m *ChannelMonitorDailyRollupMutation) OldField(ctx context.Context, name s
// type.
func (m *ChannelMonitorDailyRollupMutation) SetField(name string, value ent.Value) error {
switch name {
- case channelmonitordailyrollup.FieldDeletedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDeletedAt(v)
- return nil
case channelmonitordailyrollup.FieldMonitorID:
v, ok := value.(int64)
if !ok {
@@ -11355,11 +11291,7 @@ func (m *ChannelMonitorDailyRollupMutation) AddField(name string, value ent.Valu
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *ChannelMonitorDailyRollupMutation) ClearedFields() []string {
- var fields []string
- if m.FieldCleared(channelmonitordailyrollup.FieldDeletedAt) {
- fields = append(fields, channelmonitordailyrollup.FieldDeletedAt)
- }
- return fields
+ return nil
}
// FieldCleared returns a boolean indicating if a field with the given name was
@@ -11372,11 +11304,6 @@ func (m *ChannelMonitorDailyRollupMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *ChannelMonitorDailyRollupMutation) ClearField(name string) error {
- switch name {
- case channelmonitordailyrollup.FieldDeletedAt:
- m.ClearDeletedAt()
- return nil
- }
return fmt.Errorf("unknown ChannelMonitorDailyRollup nullable field %s", name)
}
@@ -11384,9 +11311,6 @@ func (m *ChannelMonitorDailyRollupMutation) ClearField(name string) error {
// It returns an error if the field is not defined in the schema.
func (m *ChannelMonitorDailyRollupMutation) ResetField(name string) error {
switch name {
- case channelmonitordailyrollup.FieldDeletedAt:
- m.ResetDeletedAt()
- return nil
case channelmonitordailyrollup.FieldMonitorID:
m.ResetMonitorID()
return nil
@@ -11513,7 +11437,6 @@ type ChannelMonitorHistoryMutation struct {
op Op
typ string
id *int64
- deleted_at *time.Time
model *string
status *channelmonitorhistory.Status
latency_ms *int
@@ -11628,55 +11551,6 @@ func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error
}
}
-// SetDeletedAt sets the "deleted_at" field.
-func (m *ChannelMonitorHistoryMutation) SetDeletedAt(t time.Time) {
- m.deleted_at = &t
-}
-
-// DeletedAt returns the value of the "deleted_at" field in the mutation.
-func (m *ChannelMonitorHistoryMutation) DeletedAt() (r time.Time, exists bool) {
- v := m.deleted_at
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldDeletedAt returns the old "deleted_at" field's value of the ChannelMonitorHistory entity.
-// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *ChannelMonitorHistoryMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldDeletedAt requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err)
- }
- return oldValue.DeletedAt, nil
-}
-
-// ClearDeletedAt clears the value of the "deleted_at" field.
-func (m *ChannelMonitorHistoryMutation) ClearDeletedAt() {
- m.deleted_at = nil
- m.clearedFields[channelmonitorhistory.FieldDeletedAt] = struct{}{}
-}
-
-// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation.
-func (m *ChannelMonitorHistoryMutation) DeletedAtCleared() bool {
- _, ok := m.clearedFields[channelmonitorhistory.FieldDeletedAt]
- return ok
-}
-
-// ResetDeletedAt resets all changes to the "deleted_at" field.
-func (m *ChannelMonitorHistoryMutation) ResetDeletedAt() {
- m.deleted_at = nil
- delete(m.clearedFields, channelmonitorhistory.FieldDeletedAt)
-}
-
// SetMonitorID sets the "monitor_id" field.
func (m *ChannelMonitorHistoryMutation) SetMonitorID(i int64) {
m.monitor = &i
@@ -12071,10 +11945,7 @@ func (m *ChannelMonitorHistoryMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorHistoryMutation) Fields() []string {
- fields := make([]string, 0, 8)
- if m.deleted_at != nil {
- fields = append(fields, channelmonitorhistory.FieldDeletedAt)
- }
+ fields := make([]string, 0, 7)
if m.monitor != nil {
fields = append(fields, channelmonitorhistory.FieldMonitorID)
}
@@ -12104,8 +11975,6 @@ func (m *ChannelMonitorHistoryMutation) Fields() []string {
// schema.
func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) {
switch name {
- case channelmonitorhistory.FieldDeletedAt:
- return m.DeletedAt()
case channelmonitorhistory.FieldMonitorID:
return m.MonitorID()
case channelmonitorhistory.FieldModel:
@@ -12129,8 +11998,6 @@ func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) {
// database failed.
func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
- case channelmonitorhistory.FieldDeletedAt:
- return m.OldDeletedAt(ctx)
case channelmonitorhistory.FieldMonitorID:
return m.OldMonitorID(ctx)
case channelmonitorhistory.FieldModel:
@@ -12154,13 +12021,6 @@ func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name strin
// type.
func (m *ChannelMonitorHistoryMutation) SetField(name string, value ent.Value) error {
switch name {
- case channelmonitorhistory.FieldDeletedAt:
- v, ok := value.(time.Time)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetDeletedAt(v)
- return nil
case channelmonitorhistory.FieldMonitorID:
v, ok := value.(int64)
if !ok {
@@ -12267,9 +12127,6 @@ func (m *ChannelMonitorHistoryMutation) AddField(name string, value ent.Value) e
// mutation.
func (m *ChannelMonitorHistoryMutation) ClearedFields() []string {
var fields []string
- if m.FieldCleared(channelmonitorhistory.FieldDeletedAt) {
- fields = append(fields, channelmonitorhistory.FieldDeletedAt)
- }
if m.FieldCleared(channelmonitorhistory.FieldLatencyMs) {
fields = append(fields, channelmonitorhistory.FieldLatencyMs)
}
@@ -12293,9 +12150,6 @@ func (m *ChannelMonitorHistoryMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ClearField(name string) error {
switch name {
- case channelmonitorhistory.FieldDeletedAt:
- m.ClearDeletedAt()
- return nil
case channelmonitorhistory.FieldLatencyMs:
m.ClearLatencyMs()
return nil
@@ -12313,9 +12167,6 @@ func (m *ChannelMonitorHistoryMutation) ClearField(name string) error {
// It returns an error if the field is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ResetField(name string) error {
switch name {
- case channelmonitorhistory.FieldDeletedAt:
- m.ResetDeletedAt()
- return nil
case channelmonitorhistory.FieldMonitorID:
m.ResetMonitorID()
return nil
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 25076444..63552bb5 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -521,11 +521,6 @@ func init() {
channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
// channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
- channelmonitordailyrollupMixin := schema.ChannelMonitorDailyRollup{}.Mixin()
- channelmonitordailyrollupMixinHooks0 := channelmonitordailyrollupMixin[0].Hooks()
- channelmonitordailyrollup.Hooks[0] = channelmonitordailyrollupMixinHooks0[0]
- channelmonitordailyrollupMixinInters0 := channelmonitordailyrollupMixin[0].Interceptors()
- channelmonitordailyrollup.Interceptors[0] = channelmonitordailyrollupMixinInters0[0]
channelmonitordailyrollupFields := schema.ChannelMonitorDailyRollup{}.Fields()
_ = channelmonitordailyrollupFields
// channelmonitordailyrollupDescModel is the schema descriptor for model field.
@@ -592,11 +587,6 @@ func init() {
channelmonitordailyrollup.DefaultComputedAt = channelmonitordailyrollupDescComputedAt.Default.(func() time.Time)
// channelmonitordailyrollup.UpdateDefaultComputedAt holds the default value on update for the computed_at field.
channelmonitordailyrollup.UpdateDefaultComputedAt = channelmonitordailyrollupDescComputedAt.UpdateDefault.(func() time.Time)
- channelmonitorhistoryMixin := schema.ChannelMonitorHistory{}.Mixin()
- channelmonitorhistoryMixinHooks0 := channelmonitorhistoryMixin[0].Hooks()
- channelmonitorhistory.Hooks[0] = channelmonitorhistoryMixinHooks0[0]
- channelmonitorhistoryMixinInters0 := channelmonitorhistoryMixin[0].Interceptors()
- channelmonitorhistory.Interceptors[0] = channelmonitorhistoryMixinInters0[0]
channelmonitorhistoryFields := schema.ChannelMonitorHistory{}.Fields()
_ = channelmonitorhistoryFields
// channelmonitorhistoryDescModel is the schema descriptor for model field.
diff --git a/backend/ent/schema/channel_monitor_daily_rollup.go b/backend/ent/schema/channel_monitor_daily_rollup.go
index 574a28d9..23f032e3 100644
--- a/backend/ent/schema/channel_monitor_daily_rollup.go
+++ b/backend/ent/schema/channel_monitor_daily_rollup.go
@@ -10,13 +10,12 @@ import (
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
-
- "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
)
// ChannelMonitorDailyRollup 按 (monitor_id, model, bucket_date) 维度聚合的渠道监控日统计。
// 每天的明细被收敛为一行(保留 status 分布 + 延迟和),用于 7d/15d/30d 窗口的可用率
// 加权计算(avg_latency = sum_latency_ms / count_latency;availability = ok_count / total_checks)。
+// 超过保留期由每日维护任务分批物理删(不用软删除,理由同 channel_monitor_history)。
type ChannelMonitorDailyRollup struct {
ent.Schema
}
@@ -27,12 +26,6 @@ func (ChannelMonitorDailyRollup) Annotations() []schema.Annotation {
}
}
-func (ChannelMonitorDailyRollup) Mixin() []ent.Mixin {
- return []ent.Mixin{
- mixins.SoftDeleteMixin{},
- }
-}
-
func (ChannelMonitorDailyRollup) Fields() []ent.Field {
return []ent.Field{
field.Int64("monitor_id"),
diff --git a/backend/ent/schema/channel_monitor_history.go b/backend/ent/schema/channel_monitor_history.go
index ec54b34f..4366e79a 100644
--- a/backend/ent/schema/channel_monitor_history.go
+++ b/backend/ent/schema/channel_monitor_history.go
@@ -9,13 +9,12 @@ import (
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
-
- "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
)
// ChannelMonitorHistory holds the schema definition for the ChannelMonitorHistory entity.
-// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天的数据被聚合到
-// channel_monitor_daily_rollups 后软删(deleted_at),由后续懒清理任务物理移除。
+// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天由每日维护任务
+// 先聚合到 channel_monitor_daily_rollups,再分批物理删(不用软删除:日志类表无恢复
+// 需求,软删会让行和索引只增不减,徒增磁盘和查询开销)。
type ChannelMonitorHistory struct {
ent.Schema
}
@@ -26,12 +25,6 @@ func (ChannelMonitorHistory) Annotations() []schema.Annotation {
}
}
-func (ChannelMonitorHistory) Mixin() []ent.Mixin {
- return []ent.Mixin{
- mixins.SoftDeleteMixin{},
- }
-}
-
func (ChannelMonitorHistory) Fields() []ent.Field {
return []ent.Field{
field.Int64("monitor_id"),
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
index badbdbca..f4e2a0ec 100644
--- a/backend/internal/repository/channel_monitor_repo.go
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -9,7 +9,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
- "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
@@ -195,15 +194,10 @@ func (r *channelMonitorRepository) InsertHistoryBatch(ctx context.Context, rows
return nil
}
+// DeleteHistoryBefore 物理删 checked_at < before 的明细,每批最多 channelMonitorPruneBatchSize 行,
+// 避免单事务删除过多引起锁/WAL 压力。借助 (checked_at) 索引定位小批 id,再按 id 删。
func (r *channelMonitorRepository) DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error) {
- client := clientFromContext(ctx, r.client)
- n, err := client.ChannelMonitorHistory.Delete().
- Where(channelmonitorhistory.CheckedAtLT(before)).
- Exec(ctx)
- if err != nil {
- return 0, fmt.Errorf("delete history before: %w", err)
- }
- return int64(n), nil
+ return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneHistorySQL, before)
}
// ListHistory 按 checked_at 倒序返回某个监控的最近 N 条历史记录。
@@ -247,7 +241,6 @@ func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monit
model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = $1
- AND deleted_at IS NULL
ORDER BY model, checked_at DESC
`
rows, err := r.db.QueryContext(ctx, q, monitorID)
@@ -302,7 +295,6 @@ func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, moni
COUNT(latency_ms) AS count_latency
FROM channel_monitor_histories
WHERE monitor_id = $1
- AND deleted_at IS NULL
AND checked_at >= CURRENT_DATE
GROUP BY model
),
@@ -310,7 +302,6 @@ func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, moni
SELECT model, total_checks, ok_count, sum_latency_ms, count_latency
FROM channel_monitor_daily_rollups
WHERE monitor_id = $1
- AND deleted_at IS NULL
AND bucket_date >= (CURRENT_DATE - $2::int)
AND bucket_date < CURRENT_DATE
)
@@ -376,7 +367,6 @@ func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context,
monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = ANY($1)
- AND deleted_at IS NULL
ORDER BY monitor_id, model, checked_at DESC
`
rows, err := r.db.QueryContext(ctx, q, pq.Array(ids))
@@ -437,7 +427,6 @@ func (r *channelMonitorRepository) ListRecentHistoryForMonitors(
FROM channel_monitor_histories h
JOIN targets t
ON t.monitor_id = h.monitor_id AND t.model = h.model
- WHERE h.deleted_at IS NULL
)
SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
FROM ranked
@@ -524,7 +513,6 @@ func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Co
COUNT(latency_ms) AS count_latency
FROM channel_monitor_histories
WHERE monitor_id = ANY($1)
- AND deleted_at IS NULL
AND checked_at >= CURRENT_DATE
GROUP BY monitor_id, model
),
@@ -532,7 +520,6 @@ func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Co
SELECT monitor_id, model, total_checks, ok_count, sum_latency_ms, count_latency
FROM channel_monitor_daily_rollups
WHERE monitor_id = ANY($1)
- AND deleted_at IS NULL
AND bucket_date >= (CURRENT_DATE - $2::int)
AND bucket_date < CURRENT_DATE
)
@@ -572,11 +559,10 @@ func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Co
// ---------- 聚合维护 ----------
-// UpsertDailyRollupsFor 把 targetDate 当天([targetDate, targetDate+1d))未软删的明细
+// UpsertDailyRollupsFor 把 targetDate 当天([targetDate, targetDate+1d))的明细
// 按 (monitor_id, model, bucket_date) 聚合写入 channel_monitor_daily_rollups。
// - 用 ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE 实现幂等回填,
// 重复执行只会用最新统计覆盖;
-// - 同时把 deleted_at 重置为 NULL,避免历史误删后聚合行被持续过滤掉;
// - $1::date 让 PG 自动把入参 truncate 到 UTC 日期,调用方不需要预处理 targetDate。
func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error) {
const q = `
@@ -604,8 +590,7 @@ func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, ta
COUNT(ping_latency_ms) AS count_ping_latency,
NOW()
FROM channel_monitor_histories
- WHERE deleted_at IS NULL
- AND checked_at >= $1::date
+ WHERE checked_at >= $1::date
AND checked_at < ($1::date + INTERVAL '1 day')
GROUP BY monitor_id, model
ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET
@@ -619,8 +604,7 @@ func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, ta
count_latency = EXCLUDED.count_latency,
sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms,
count_ping_latency = EXCLUDED.count_ping_latency,
- computed_at = NOW(),
- deleted_at = NULL
+ computed_at = NOW()
`
res, err := r.db.ExecContext(ctx, q, targetDate)
if err != nil {
@@ -633,17 +617,59 @@ func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, ta
return n, nil
}
-// DeleteRollupsBefore 软删 bucket_date < beforeDate 的聚合行。
-// 走 ent client,利用 SoftDeleteMixin 把 DELETE 自动改写为 UPDATE deleted_at = NOW()。
+// DeleteRollupsBefore 物理删 bucket_date < beforeDate 的聚合行,同样分批。
func (r *channelMonitorRepository) DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error) {
- client := clientFromContext(ctx, r.client)
- n, err := client.ChannelMonitorDailyRollup.Delete().
- Where(channelmonitordailyrollup.BucketDateLT(beforeDate)).
- Exec(ctx)
- if err != nil {
- return 0, fmt.Errorf("delete rollups before: %w", err)
+ return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneRollupSQL, beforeDate)
+}
+
+// channelMonitorPruneBatchSize 单批删除上限。与 ops_cleanup_service 保持一致的 5000,
+// 在大表上按 id 小批删可以避免长事务和 WAL 堆积。
+const channelMonitorPruneBatchSize = 5000
+
+// channelMonitorPruneHistorySQL 分批物理删明细表过期行。
+const channelMonitorPruneHistorySQL = `
+WITH batch AS (
+ SELECT id FROM channel_monitor_histories
+ WHERE checked_at < $1
+ ORDER BY id
+ LIMIT $2
+)
+DELETE FROM channel_monitor_histories
+WHERE id IN (SELECT id FROM batch)
+`
+
+// channelMonitorPruneRollupSQL 分批物理删 rollup 表过期行。bucket_date 需要 ::date 转型
+// 保证与 DATE 列一致比较。
+const channelMonitorPruneRollupSQL = `
+WITH batch AS (
+ SELECT id FROM channel_monitor_daily_rollups
+ WHERE bucket_date < $1::date
+ ORDER BY id
+ LIMIT $2
+)
+DELETE FROM channel_monitor_daily_rollups
+WHERE id IN (SELECT id FROM batch)
+`
+
+// deleteChannelMonitorBatched 循环执行分批 DELETE,直到影响行数为 0。返回累计删除行数。
+// cutoff 由调用方按列类型传入(明细用 time.Time 对 TIMESTAMPTZ,rollup 用 time.Time SQL 侧 ::date 转型)。
+func deleteChannelMonitorBatched(ctx context.Context, db *sql.DB, query string, cutoff time.Time) (int64, error) {
+ var total int64
+ for {
+ res, err := db.ExecContext(ctx, query, cutoff, channelMonitorPruneBatchSize)
+ if err != nil {
+ return total, fmt.Errorf("channel_monitor prune batch: %w", err)
+ }
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return total, fmt.Errorf("channel_monitor prune rows affected: %w", err)
+ }
+ total += affected
+ if affected == 0 {
+ break
+ }
}
- return int64(n), nil
+ return total, nil
}
// LoadAggregationWatermark 读 watermark 表(id=1)。
diff --git a/backend/migrations/127_drop_channel_monitor_deleted_at.sql b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
new file mode 100644
index 00000000..2260f06b
--- /dev/null
+++ b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
@@ -0,0 +1,16 @@
+-- Migration: 127_drop_channel_monitor_deleted_at
+-- 纠正 110 引入的 SoftDeleteMixin:日志/聚合表无恢复需求,软删会让行和索引只增不减,
+-- 徒增磁盘和查询开销。改回分批物理删(由 OpsCleanupService 每天凌晨统一调度,
+-- deleteOldRowsByID 模板,batch=5000)。
+--
+-- 110 尚未跑过聚合/清理(首次 maintenance 在次日 02:00),所以此处不担心业务数据。
+-- 直接 DROP 列 + 索引;对应的 Go 侧 ent schema 已移除 SoftDeleteMixin、repo 的
+-- raw SQL 已移除 deleted_at IS NULL 过滤。
+
+DROP INDEX IF EXISTS idx_channel_monitor_histories_deleted_at;
+ALTER TABLE channel_monitor_histories
+ DROP COLUMN IF EXISTS deleted_at;
+
+DROP INDEX IF EXISTS idx_channel_monitor_daily_rollups_deleted_at;
+ALTER TABLE channel_monitor_daily_rollups
+ DROP COLUMN IF EXISTS deleted_at;
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 8b9fbdea..248e0021 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -199,6 +199,28 @@ interface NavItem {
* does NOT navigate to its `path`. The `path` is purely a stable key.
*/
expandOnly?: boolean
+ /**
+ * 可选的功能开关 getter。返回 false 时菜单项被隐藏;返回 undefined/true 时显示。
+ * 宽容策略(undefined → 显示)避免 public settings 未加载完成时菜单闪烁消失。
+ * Getter 里访问的 reactive 来源(store / composable)会被 computed 自动追踪,
+ * 开关切换时菜单自动更新。
+ */
+ featureFlag?: () => boolean | undefined
+}
+
+// applyFeatureFlags 递归过滤掉 featureFlag() === false 的节点(含子节点)。
+// 使用 `!== false` 宽容语义:undefined(设置未加载)或 true 都视为显示。
+function applyFeatureFlags(items: NavItem[]): NavItem[] {
+ const out: NavItem[] = []
+ for (const item of items) {
+ if (item.featureFlag && item.featureFlag() === false) continue
+ if (item.children) {
+ out.push({ ...item, children: applyFeatureFlags(item.children) })
+ } else {
+ out.push(item)
+ }
+ }
+ return out
}
const { t } = useI18n()
@@ -605,36 +627,27 @@ const ChevronDownIcon = {
)
}
-// User navigation items (for regular users)
-const userNavItems = computed((): NavItem[] => {
- const items: NavItem[] = [
- { path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon },
+// 各个开关集中声明:所有菜单项引用这里的 getter,未来加新开关只需在此加一个常量。
+// getter 返回 false = 隐藏;undefined/true = 显示(宽容策略,避免 public settings 未加载闪烁)。
+const flagChannelMonitor = () => appStore.cachedPublicSettings?.channel_monitor_enabled
+const flagPayment = () => appStore.cachedPublicSettings?.payment_enabled
+const flagOpsMonitoring = () => adminSettingsStore.opsMonitoringEnabled
+const flagAdminPayment = () => adminSettingsStore.paymentEnabled
+
+// buildSelfNavItems 构造用户自己的导航项(用户端主菜单和管理员的"我的账户"子菜单共享这组声明)。
+// withDashboard=true 时包含仪表盘(用户端),false 时不含(管理员的个人区已经有独立仪表盘入口)。
+function buildSelfNavItems(withDashboard: boolean): NavItem[] {
+ const items: NavItem[] = []
+ if (withDashboard) {
+ items.push({ path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon })
+ }
+ items.push(
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
- ...(appStore.cachedPublicSettings?.channel_monitor_enabled
- ? [{ path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }]
- : []),
+ { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon, featureFlag: flagChannelMonitor },
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/purchase',
- label: t('nav.buySubscription'),
- icon: RechargeSubscriptionIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/orders',
- label: t('nav.myOrders'),
- icon: OrderListIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
+ { path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment },
+ { path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
{ path: '/profile', label: t('nav.profile'), icon: UserIcon },
...customMenuItemsForUser.value.map((item): NavItem => ({
@@ -643,50 +656,21 @@ const userNavItems = computed((): NavItem[] => {
icon: null,
iconSvg: item.icon_svg,
})),
- ]
- return authStore.isSimpleMode ? items.filter(item => !item.hideInSimpleMode) : items
-})
+ )
+ return items
+}
+
+// finalizeNav 合并两重过滤:featureFlag 过滤 + simple 模式过滤。
+function finalizeNav(items: NavItem[]): NavItem[] {
+ const visible = applyFeatureFlags(items)
+ return authStore.isSimpleMode ? visible.filter(item => !item.hideInSimpleMode) : visible
+}
+
+// User navigation items (for regular users)
+const userNavItems = computed((): NavItem[] => finalizeNav(buildSelfNavItems(true)))
// Personal navigation items (for admin's "My Account" section, without Dashboard)
-const personalNavItems = computed((): NavItem[] => {
- const items: NavItem[] = [
- { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
- { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
- ...(appStore.cachedPublicSettings?.channel_monitor_enabled
- ? [{ path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }]
- : []),
- { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/purchase',
- label: t('nav.buySubscription'),
- icon: RechargeSubscriptionIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/orders',
- label: t('nav.myOrders'),
- icon: OrderListIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
- { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
- { path: '/profile', label: t('nav.profile'), icon: UserIcon },
- ...customMenuItemsForUser.value.map((item): NavItem => ({
- path: `/custom/${item.id}`,
- label: item.label,
- icon: null,
- iconSvg: item.icon_svg,
- })),
- ]
- return authStore.isSimpleMode ? items.filter(item => !item.hideInSimpleMode) : items
-})
+const personalNavItems = computed((): NavItem[] => finalizeNav(buildSelfNavItems(false)))
// Custom menu items filtered by visibility
const customMenuItemsForUser = computed(() => {
@@ -706,9 +690,7 @@ const customMenuItemsForAdmin = computed(() => {
const adminNavItems = computed((): NavItem[] => {
const baseItems: NavItem[] = [
{ path: '/admin/dashboard', label: t('nav.dashboard'), icon: DashboardIcon },
- ...(adminSettingsStore.opsMonitoringEnabled
- ? [{ path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon }]
- : []),
+ { path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon, featureFlag: flagOpsMonitoring },
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
{
@@ -719,9 +701,7 @@ const adminNavItems = computed((): NavItem[] => {
expandOnly: true,
children: [
{ path: '/admin/channels/pricing', label: t('nav.channelPricing'), icon: PriceTagIcon },
- ...(appStore.cachedPublicSettings?.channel_monitor_enabled
- ? [{ path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon }]
- : []),
+ { path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon, featureFlag: flagChannelMonitor },
],
},
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
@@ -730,43 +710,40 @@ const adminNavItems = computed((): NavItem[] => {
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
- ...(adminSettingsStore.paymentEnabled
- ? [
- {
- path: '/admin/orders',
- label: t('nav.orderManagement'),
- icon: OrderIcon,
- hideInSimpleMode: true,
- expandOnly: true,
- children: [
- { path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon },
- { path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon },
- { path: '/admin/orders/plans', label: t('nav.paymentPlans'), icon: CreditCardIcon },
- ],
- },
- ]
- : []),
+ {
+ path: '/admin/orders',
+ label: t('nav.orderManagement'),
+ icon: OrderIcon,
+ hideInSimpleMode: true,
+ expandOnly: true,
+ featureFlag: flagAdminPayment,
+ children: [
+ { path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon },
+ { path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon },
+ { path: '/admin/orders/plans', label: t('nav.paymentPlans'), icon: CreditCardIcon },
+ ],
+ },
{ path: '/admin/usage', label: t('nav.usage'), icon: ChartIcon }
]
+ const visible = applyFeatureFlags(baseItems)
+
// 简单模式下,在系统设置前插入 API密钥
if (authStore.isSimpleMode) {
- const filtered = baseItems.filter(item => !item.hideInSimpleMode)
+ const filtered = visible.filter(item => !item.hideInSimpleMode)
filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon })
filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
- // Add admin custom menu items after settings
for (const cm of customMenuItemsForAdmin.value) {
filtered.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg })
}
return filtered
}
- baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
- // Add admin custom menu items after settings
+ visible.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
for (const cm of customMenuItemsForAdmin.value) {
- baseItems.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg })
+ visible.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg })
}
- return baseItems
+ return visible
})
function toggleSidebar() {
--
GitLab
From b363bff1d8024cae552fe494bc49195f238e2a02 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 11:59:11 +0800
Subject: [PATCH 090/261] feat(channel-monitor): preserve upstream error body
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Monitor:
- callProvider now returns both textPath-extracted text and raw body;
runCheckForModel uses rawBody on non-2xx so history.message stops being
"upstream HTTP 503: " with empty body (gjson textPath produces "" for
error responses like {"error":{"message":"No available accounts..."}})
- truncateForErrorBody collapses whitespace then caps at 300 bytes
(monitorErrorBodySnippetMaxBytes); final truncateMessage still enforces
the 500-byte DB column cap
Frontend:
- MonitorFormDialog: primary_model input text color and ModelTagInput tags
now both track form.provider (via new getPlatformTextClass + existing
getPlatformTagClass with platform prop).
(cherry-picked from 1d3b0418; dropped gateway_handler logging改动,不在本 PR 范围)
---
.../service/channel_monitor_checker.go | 42 +++++++++++++++----
.../internal/service/channel_monitor_const.go | 4 ++
.../src/components/admin/channel/types.ts | 11 +++++
.../admin/monitor/MonitorFormDialog.vue | 11 ++++-
4 files changed, 58 insertions(+), 10 deletions(-)
diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go
index ba5ce0e8..e03c2e3a 100644
--- a/backend/internal/service/channel_monitor_checker.go
+++ b/backend/internal/service/channel_monitor_checker.go
@@ -49,7 +49,7 @@ func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model str
challenge := generateChallenge()
start := time.Now()
- respText, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt)
+ respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt)
latency := time.Since(start)
latencyMs := int(latency / time.Millisecond)
res.LatencyMs = &latencyMs
@@ -60,8 +60,11 @@ func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model str
return res
}
if statusCode < 200 || statusCode >= 300 {
+ // 错误路径:用 rawBody 而非 respText(gjson textPath 抽取在错误响应里通常为空,
+ // 会丢掉真正的上游错误信息,例如 `{"error":{"message":"No available accounts ..."}}`)。
res.Status = MonitorStatusError
- res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, respText)))
+ bodySnippet := truncateForErrorBody(rawBody)
+ res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, bodySnippet)))
return res
}
@@ -180,22 +183,27 @@ func isSupportedProvider(p string) bool {
}
// callProvider 通过 providerAdapters 分发到具体实现。
-// 返回值:响应中提取的文本、HTTP status、网络/序列化错误。
-func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string) (string, int, error) {
+//
+// 返回值:
+// - extractedText: 按 textPath 抽出的成功文本,仅在 status 2xx 时有意义;非 2xx 时通常为空串
+// - rawBody: 完整响应体的字符串形式(已被 monitorResponseMaxBytes 截断),用于错误路径保留上游真实回包
+// - status: HTTP 状态码
+// - err: 网络 / 序列化错误
+func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string) (extractedText, rawBody string, status int, err error) {
adapter, ok := providerAdapters[provider]
if !ok {
- return "", 0, fmt.Errorf("unsupported provider %q", provider)
+ return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
}
body, err := adapter.buildBody(model, prompt)
if err != nil {
- return "", 0, fmt.Errorf("marshal body: %w", err)
+ return "", "", 0, fmt.Errorf("marshal body: %w", err)
}
full := joinURL(endpoint, adapter.buildPath(model))
- respBody, status, err := postRawJSON(ctx, full, body, adapter.buildHeaders(apiKey))
+ respBytes, status, err := postRawJSON(ctx, full, body, adapter.buildHeaders(apiKey))
if err != nil {
- return "", status, err
+ return "", "", status, err
}
- return gjson.GetBytes(respBody, adapter.textPath).String(), status, nil
+ return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil
}
// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
@@ -297,3 +305,19 @@ func truncateMessage(msg string) string {
}
return msg[:cutoff] + ellipsis
}
+
+// truncateForErrorBody collapses runs of whitespace in an upstream error body
+// into single spaces (upstream HTML error pages carry heavy indentation), then
+// caps it at monitorErrorBodySnippetMaxBytes; truncateMessage still enforces the final cap.
+func truncateForErrorBody(body string) string {
+	body = strings.Join(strings.Fields(body), " ")
+	if len(body) <= monitorErrorBodySnippetMaxBytes {
+		return body
+	}
+	const ellipsis = "...(body truncated)"
+	cutoff := monitorErrorBodySnippetMaxBytes - len(ellipsis)
+	if cutoff < 0 {
+		cutoff = 0
+	}
+	return body[:cutoff] + ellipsis
+}
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
index b61f3bdd..768a432f 100644
--- a/backend/internal/service/channel_monitor_const.go
+++ b/backend/internal/service/channel_monitor_const.go
@@ -36,6 +36,10 @@ const (
monitorMessageMaxBytes = 500
// monitorResponseMaxBytes 单次模型响应最大读取字节,防止 OOM。
monitorResponseMaxBytes = 64 * 1024
+ // monitorErrorBodySnippetMaxBytes 非 2xx 响应时保留上游 body 片段的最大字节数。
+ // 留 300 字节足够覆盖典型结构化错误(如 `{"error":{"message":"..."}}`),
+ // 又给 "upstream HTTP : " 前缀留出余量,避免最终被 monitorMessageMaxBytes (500) 截得太狠。
+ monitorErrorBodySnippetMaxBytes = 300
// monitorChallengeMin / monitorChallengeMax challenge 操作数范围。
monitorChallengeMin = 1
monitorChallengeMax = 50
diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts
index b3966289..955b6487 100644
--- a/frontend/src/components/admin/channel/types.ts
+++ b/frontend/src/components/admin/channel/types.ts
@@ -187,3 +187,14 @@ export function getPlatformTagClass(platform: string): string {
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
}
}
+
+/** Text color for a platform's model name (text-* classes only, for input/text contexts) — same palette as getPlatformTagClass */
+export function getPlatformTextClass(platform: string): string {
+  switch (platform) {
+    case 'anthropic': return 'text-orange-700 dark:text-orange-400'
+    case 'openai': return 'text-emerald-700 dark:text-emerald-400'
+    case 'gemini': return 'text-blue-700 dark:text-blue-400'
+    case 'antigravity': return 'text-purple-700 dark:text-purple-400'
+    default: return ''
+  }
+}
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
index 836ec079..4a538fcf 100644
--- a/frontend/src/components/admin/monitor/MonitorFormDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -60,13 +60,21 @@
{{ t('admin.channelMonitor.form.primaryModel') }} *
-
+
{{ t('admin.channelMonitor.form.extraModels') }}
@@ -137,6 +145,7 @@ import type { ApiKey } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Toggle from '@/components/common/Toggle.vue'
import ModelTagInput from '@/components/admin/channel/ModelTagInput.vue'
+import { getPlatformTextClass } from '@/components/admin/channel/types'
import MonitorKeyPickerDialog from '@/components/admin/monitor/MonitorKeyPickerDialog.vue'
import ProviderIcon from '@/components/user/monitor/ProviderIcon.vue'
import { useChannelMonitorFormat } from '@/composables/useChannelMonitorFormat'
--
GitLab
From 0c48f08f5c748bdec005e31344a6d672b1d8ad83 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 12:12:08 +0800
Subject: [PATCH 091/261] refactor(channel-status): drop breadcrumb + subtitle
from MonitorHero
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The "CHANNEL · STATUS" breadcrumb and the zh/en subtitles above the
window-picker were redundant with the existing "渠道状态" page title
shown in the layout header. Remove the left column and right-align the
7d/15d/30d tabs + overall chip.
Also drop the now-unreferenced channelStatus.hero.* i18n keys from both
locales (grep confirms no remaining usage).
chore: bump version to 0.1.114.31
---
.../src/components/user/monitor/MonitorHero.vue | 14 +-------------
frontend/src/i18n/locales/en.ts | 5 -----
frontend/src/i18n/locales/zh.ts | 5 -----
3 files changed, 1 insertion(+), 23 deletions(-)
diff --git a/frontend/src/components/user/monitor/MonitorHero.vue b/frontend/src/components/user/monitor/MonitorHero.vue
index 6857a6fe..7fc4d846 100644
--- a/frontend/src/components/user/monitor/MonitorHero.vue
+++ b/frontend/src/components/user/monitor/MonitorHero.vue
@@ -1,18 +1,6 @@
-
- {{ t('channelStatus.hero.breadcrumb') }}
-
-
-
-
- {{ t('channelStatus.hero.subtitleZh') }}
-
-
- {{ t('channelStatus.hero.subtitleEn') }}
-
-
-
+
Date: Tue, 21 Apr 2026 14:14:49 +0800
Subject: [PATCH 092/261] feat(channel-monitor): request templates with
snapshot apply + headers/body override
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Problem:
Upstream channels can reject monitor probes based on client fingerprint
(e.g. "only Claude Code clients allowed"). The monitor had no way to
customize the outgoing request to bypass such restrictions.
Solution:
Introduce reusable request templates that carry extra_headers plus an
optional body override; monitors reference a template and receive a
snapshot copy on apply. Template edits do NOT auto-propagate — users
must click "apply to associated monitors" to refresh snapshots, so a
bad template edit cannot instantly break all production monitors.
Data model (migration 128):
- channel_monitor_request_templates: id, name, provider, description,
extra_headers jsonb, body_override_mode ('off'|'merge'|'replace'),
body_override jsonb. Unique (provider, name).
- channel_monitors: +template_id (FK, ON DELETE SET NULL), +extra_headers,
+body_override_mode, +body_override (the three runtime snapshot fields).
Checker (channel_monitor_checker.go):
- callProvider + runCheckForModel accept a CheckOptions carrying the
snapshot fields. mergeHeaders applies user headers on top of adapter
defaults (forbidden list: Host / Content-Length / Transfer-Encoding /
Connection / Content-Encoding).
- buildRequestBody:
off -> adapter default body
merge -> shallow-merge over default; per-provider deny list
(model/messages/contents) protects the challenge contract
replace -> user body verbatim
- Replace mode skips challenge validation; instead HTTP 2xx + non-empty
extracted response text = operational, empty = failed.
- 4 new unit tests cover all three modes + replace/empty-response case.
Admin API:
- /admin/channel-monitor-templates CRUD + /:id/apply (overwrite snapshot
on all template_id=id monitors, returns affected count).
- channel_monitor request/response DTOs gain the 4 new fields.
Frontend:
- channelMonitorTemplate.ts API client.
- MonitorAdvancedRequestConfig.vue shared component for headers textarea
+ body mode radio + body JSON editor; used by both template and monitor
forms.
- MonitorTemplateManagerDialog.vue: provider tabs, list/create/edit/
delete/apply, live "associated monitors" count per row.
- MonitorFiltersBar: new 模板管理 button next to 新增监控.
- MonitorFormDialog: collapsible 高级 section with template dropdown
(filtered by form.provider, clears on provider change) + embedded
AdvancedRequestConfig. Picking a template copies its fields into the
form (snapshot semantics mirrored on the client).
- i18n zh/en entries for all new copy.
chore: bump version to 0.1.114.32
---
_parse_upstream.py | 78 +
backend/cmd/server/wire_gen.go | 5 +-
backend/ent/channelmonitor.go | 78 +-
backend/ent/channelmonitor/channelmonitor.go | 51 +
backend/ent/channelmonitor/where.go | 138 ++
backend/ent/channelmonitor_create.go | 308 ++++
backend/ent/channelmonitor_query.go | 108 +-
backend/ent/channelmonitor_update.go | 247 ++++
backend/ent/channelmonitorrequesttemplate.go | 216 +++
.../channelmonitorrequesttemplate.go | 172 +++
.../channelmonitorrequesttemplate/where.go | 434 ++++++
.../channelmonitorrequesttemplate_create.go | 942 ++++++++++++
.../channelmonitorrequesttemplate_delete.go | 88 ++
.../channelmonitorrequesttemplate_query.go | 648 +++++++++
.../channelmonitorrequesttemplate_update.go | 639 ++++++++
backend/ent/client.go | 347 +++--
backend/ent/ent.go | 68 +-
backend/ent/hook/hook.go | 12 +
backend/ent/intercept/intercept.go | 30 +
backend/ent/migrate/schema.go | 47 +
backend/ent/mutation.go | 1285 ++++++++++++++++-
backend/ent/predicate/predicate.go | 3 +
backend/ent/runtime/runtime.go | 60 +
backend/ent/schema/channel_monitor.go | 27 +
.../channel_monitor_request_template.go | 80 +
backend/ent/tx.go | 3 +
.../handler/admin/channel_monitor_handler.go | 105 +-
.../admin/channel_monitor_template_handler.go | 195 +++
backend/internal/handler/handler.go | 55 +-
backend/internal/handler/wire.go | 57 +-
.../repository/channel_monitor_repo.go | 83 +-
.../channel_monitor_template_repo.go | 168 +++
backend/internal/repository/wire.go | 1 +
backend/internal/server/routes/admin.go | 10 +
.../service/channel_monitor_checker.go | 134 +-
.../channel_monitor_checker_body_test.go | 173 +++
.../service/channel_monitor_service.go | 72 +-
.../channel_monitor_template_service.go | 225 +++
.../service/channel_monitor_template_types.go | 74 +
.../internal/service/channel_monitor_types.go | 51 +-
backend/internal/service/wire.go | 1 +
..._add_channel_monitor_request_templates.sql | 70 +
frontend/src/api/admin/channelMonitor.ts | 16 +-
.../src/api/admin/channelMonitorTemplate.ts | 108 ++
frontend/src/api/admin/index.ts | 3 +
.../monitor/MonitorAdvancedRequestConfig.vue | 205 +++
.../admin/monitor/MonitorFiltersBar.vue | 9 +
.../admin/monitor/MonitorFormDialog.vue | 118 +-
.../monitor/MonitorTemplateManagerDialog.vue | 465 ++++++
.../components/user/monitor/MonitorHero.vue | 87 +-
frontend/src/i18n/locales/en.ts | 52 +-
frontend/src/i18n/locales/zh.ts | 52 +-
.../src/views/admin/ChannelMonitorView.vue | 9 +
53 files changed, 8318 insertions(+), 394 deletions(-)
create mode 100644 _parse_upstream.py
create mode 100644 backend/ent/channelmonitorrequesttemplate.go
create mode 100644 backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
create mode 100644 backend/ent/channelmonitorrequesttemplate/where.go
create mode 100644 backend/ent/channelmonitorrequesttemplate_create.go
create mode 100644 backend/ent/channelmonitorrequesttemplate_delete.go
create mode 100644 backend/ent/channelmonitorrequesttemplate_query.go
create mode 100644 backend/ent/channelmonitorrequesttemplate_update.go
create mode 100644 backend/ent/schema/channel_monitor_request_template.go
create mode 100644 backend/internal/handler/admin/channel_monitor_template_handler.go
create mode 100644 backend/internal/repository/channel_monitor_template_repo.go
create mode 100644 backend/internal/service/channel_monitor_checker_body_test.go
create mode 100644 backend/internal/service/channel_monitor_template_service.go
create mode 100644 backend/internal/service/channel_monitor_template_types.go
create mode 100644 backend/migrations/128_add_channel_monitor_request_templates.sql
create mode 100644 frontend/src/api/admin/channelMonitorTemplate.ts
create mode 100644 frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
create mode 100644 frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
diff --git a/_parse_upstream.py b/_parse_upstream.py
new file mode 100644
index 00000000..807d1cac
--- /dev/null
+++ b/_parse_upstream.py
@@ -0,0 +1,78 @@
+"""
+Split upstream token and quota usage strictly per model, and recompute each model's
+cost from OUR pricing table; compare upstream provider-side (/group_ratio) results against Anthropic list prices.
+"""
+import re
+import json
+from collections import defaultdict
+
+# Aggregate per account (token_name) + model
+by_key = defaultdict(lambda: {
+    'count': 0,
+    'prompt': 0,
+    'completion': 0,
+    'cache_create': 0,
+    'cache_read': 0,
+    'quota_pre_group_sum': 0.0,
+    'flat_price_reqs': 0,
+    'flat_price_value': 0.0,
+    'model_ratios': set(),
+    'model_prices': set(),
+})
+
+with open(r"C:\Users\16790\xwechat_files\wxid_8tc8tfooo5rs22_fef8\msg\file\2026-04\asakifeng_consume.txt", 'r', encoding='utf-8') as f:
+    for line in f:
+        m = re.match(r'\[INFO\] (\d{4}/\d{2}/\d{2} - \d{2}:\d{2}:\d{2}) \|.*params=(\{.*\})\s*$', line.strip())
+        if not m: continue
+        try: p = json.loads(m.group(2))
+        except Exception: continue
+        tn = p.get('token_name', '')
+        model = p.get('model_name', '')
+        other = p.get('other') or {}
+        gr = other.get('group_ratio', 1.0) or 1.0
+        q = p.get('quota', 0) or 0
+        k = (tn, model)
+        d = by_key[k]
+        d['count'] += 1
+        d['prompt'] += p.get('prompt_tokens', 0) or 0
+        d['completion'] += p.get('completion_tokens', 0) or 0
+        d['cache_create'] += other.get('cache_creation_tokens', 0) or 0
+        d['cache_read'] += other.get('cache_tokens', 0) or 0
+        d['quota_pre_group_sum'] += q / gr if gr else q
+        mp = other.get('model_price') or 0
+        mr = other.get('model_ratio')
+        if mr is not None: d['model_ratios'].add(mr)
+        d['model_prices'].add(mp)
+        if mp and mp > 0:
+            d['flat_price_reqs'] += 1
+            d['flat_price_value'] += mp  # flat $ per request
+
+# Our pricing table (real values read from backend/resources/.../model_prices_and_context_window.json)
+OUR_PRICE = {
+    'claude-haiku-4-5-20251001': {'input': 1e-6, 'output': 5e-6, 'cc5m': 1.25e-6, 'cr': 1e-7},
+    'claude-sonnet-4-6': {'input': 3e-6, 'output': 1.5e-5, 'cc5m': 3.75e-6, 'cr': 3e-7},
+    'claude-sonnet-4-5-20250929': {'input': 3e-6, 'output': 1.5e-5, 'cc5m': 3.75e-6, 'cr': 3e-7},
+    'claude-opus-4-6': {'input': 5e-6, 'output': 2.5e-5, 'cc5m': 6.25e-6, 'cr': 5e-7},
+    'claude-opus-4-5-20251101': {'input': 5e-6, 'output': 2.5e-5, 'cc5m': 6.25e-6, 'cr': 5e-7},
+    'claude-opus-4-7': {'input': 5e-6, 'output': 2.5e-5, 'cc5m': 6.25e-6, 'cr': 5e-7},  # we fall back to the opus-4-6 price
+}
+
+print("%-40s %-28s %5s %12s %12s %12s %12s" % ("TOKEN", "MODEL", "req", "upstream$", "our_calc$", "diff$", "note"))
+print("-" * 150)
+total_up = 0.0; total_ours = 0.0
+for (tn, model), d in sorted(by_key.items()):
+    up = d['quota_pre_group_sum'] / 500000
+    p = OUR_PRICE.get(model)
+    if p:
+        ours = (d['prompt']*p['input'] + d['completion']*p['output']
+                + d['cache_create']*p['cc5m'] + d['cache_read']*p['cr'])
+    else:
+        ours = 0.0
+    diff = up - ours
+    note = ""
+    if d['flat_price_reqs']:
+        note = f"flat_price {d['flat_price_reqs']}/{d['count']}"
+    total_up += up; total_ours += ours
+    print("%-40s %-28s %5d %12.4f %12.4f %+12.4f %s" % (tn[:40], model, d['count'], up, ours, diff, note))
+print("-" * 150)
+print("%-40s %-28s %5s %12.4f %12.4f %+12.4f" % ("TOTAL", "", "", total_up, total_ours, total_up - total_ours))
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index a878ea68..4e95035a 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -215,6 +215,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err
}
channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB)
+ channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB)
+ channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
+ channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
@@ -231,7 +234,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, paymentHandler)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
diff --git a/backend/ent/channelmonitor.go b/backend/ent/channelmonitor.go
index 58886884..dbb73362 100644
--- a/backend/ent/channelmonitor.go
+++ b/backend/ent/channelmonitor.go
@@ -11,6 +11,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
)
// ChannelMonitor is the model entity for the ChannelMonitor schema.
@@ -44,6 +45,14 @@ type ChannelMonitor struct {
LastCheckedAt *time.Time `json:"last_checked_at,omitempty"`
// CreatedBy holds the value of the "created_by" field.
CreatedBy int64 `json:"created_by,omitempty"`
+ // TemplateID holds the value of the "template_id" field.
+ TemplateID *int64 `json:"template_id,omitempty"`
+ // ExtraHeaders holds the value of the "extra_headers" field.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // BodyOverrideMode holds the value of the "body_override_mode" field.
+ BodyOverrideMode string `json:"body_override_mode,omitempty"`
+ // BodyOverride holds the value of the "body_override" field.
+ BodyOverride map[string]interface{} `json:"body_override,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the ChannelMonitorQuery when eager-loading is set.
Edges ChannelMonitorEdges `json:"edges"`
@@ -56,9 +65,11 @@ type ChannelMonitorEdges struct {
History []*ChannelMonitorHistory `json:"history,omitempty"`
// DailyRollups holds the value of the daily_rollups edge.
DailyRollups []*ChannelMonitorDailyRollup `json:"daily_rollups,omitempty"`
+ // RequestTemplate holds the value of the request_template edge.
+ RequestTemplate *ChannelMonitorRequestTemplate `json:"request_template,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [2]bool
+ loadedTypes [3]bool
}
// HistoryOrErr returns the History value or an error if the edge
@@ -79,18 +90,29 @@ func (e ChannelMonitorEdges) DailyRollupsOrErr() ([]*ChannelMonitorDailyRollup,
return nil, &NotLoadedError{edge: "daily_rollups"}
}
+// RequestTemplateOrErr returns the RequestTemplate value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorEdges) RequestTemplateOrErr() (*ChannelMonitorRequestTemplate, error) {
+ if e.RequestTemplate != nil {
+ return e.RequestTemplate, nil
+ } else if e.loadedTypes[2] {
+ return nil, &NotFoundError{label: channelmonitorrequesttemplate.Label}
+ }
+ return nil, &NotLoadedError{edge: "request_template"}
+}
+
// scanValues returns the types for scanning values from sql.Rows.
func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case channelmonitor.FieldExtraModels:
+ case channelmonitor.FieldExtraModels, channelmonitor.FieldExtraHeaders, channelmonitor.FieldBodyOverride:
values[i] = new([]byte)
case channelmonitor.FieldEnabled:
values[i] = new(sql.NullBool)
- case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy:
+ case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID:
values[i] = new(sql.NullInt64)
- case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName:
+ case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
values[i] = new(sql.NullString)
case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt:
values[i] = new(sql.NullTime)
@@ -196,6 +218,35 @@ func (_m *ChannelMonitor) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.CreatedBy = value.Int64
}
+ case channelmonitor.FieldTemplateID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field template_id", values[i])
+ } else if value.Valid {
+ _m.TemplateID = new(int64)
+ *_m.TemplateID = value.Int64
+ }
+ case channelmonitor.FieldExtraHeaders:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_headers", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil {
+ return fmt.Errorf("unmarshal field extra_headers: %w", err)
+ }
+ }
+ case channelmonitor.FieldBodyOverrideMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override_mode", values[i])
+ } else if value.Valid {
+ _m.BodyOverrideMode = value.String
+ }
+ case channelmonitor.FieldBodyOverride:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil {
+ return fmt.Errorf("unmarshal field body_override: %w", err)
+ }
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -219,6 +270,11 @@ func (_m *ChannelMonitor) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
return NewChannelMonitorClient(_m.config).QueryDailyRollups(_m)
}
+// QueryRequestTemplate queries the "request_template" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery {
+ return NewChannelMonitorClient(_m.config).QueryRequestTemplate(_m)
+}
+
// Update returns a builder for updating this ChannelMonitor.
// Note that you need to call ChannelMonitor.Unwrap() before calling this method if this ChannelMonitor
// was returned from a transaction, and the transaction was committed or rolled back.
@@ -281,6 +337,20 @@ func (_m *ChannelMonitor) String() string {
builder.WriteString(", ")
builder.WriteString("created_by=")
builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
+ builder.WriteString(", ")
+ if v := _m.TemplateID; v != nil {
+ builder.WriteString("template_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("extra_headers=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders))
+ builder.WriteString(", ")
+ builder.WriteString("body_override_mode=")
+ builder.WriteString(_m.BodyOverrideMode)
+ builder.WriteString(", ")
+ builder.WriteString("body_override=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/channelmonitor/channelmonitor.go b/backend/ent/channelmonitor/channelmonitor.go
index ff6d7105..e5a6bfe7 100644
--- a/backend/ent/channelmonitor/channelmonitor.go
+++ b/backend/ent/channelmonitor/channelmonitor.go
@@ -41,10 +41,20 @@ const (
FieldLastCheckedAt = "last_checked_at"
// FieldCreatedBy holds the string denoting the created_by field in the database.
FieldCreatedBy = "created_by"
+ // FieldTemplateID holds the string denoting the template_id field in the database.
+ FieldTemplateID = "template_id"
+ // FieldExtraHeaders holds the string denoting the extra_headers field in the database.
+ FieldExtraHeaders = "extra_headers"
+ // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database.
+ FieldBodyOverrideMode = "body_override_mode"
+ // FieldBodyOverride holds the string denoting the body_override field in the database.
+ FieldBodyOverride = "body_override"
// EdgeHistory holds the string denoting the history edge name in mutations.
EdgeHistory = "history"
// EdgeDailyRollups holds the string denoting the daily_rollups edge name in mutations.
EdgeDailyRollups = "daily_rollups"
+ // EdgeRequestTemplate holds the string denoting the request_template edge name in mutations.
+ EdgeRequestTemplate = "request_template"
// Table holds the table name of the channelmonitor in the database.
Table = "channel_monitors"
// HistoryTable is the table that holds the history relation/edge.
@@ -61,6 +71,13 @@ const (
DailyRollupsInverseTable = "channel_monitor_daily_rollups"
// DailyRollupsColumn is the table column denoting the daily_rollups relation/edge.
DailyRollupsColumn = "monitor_id"
+ // RequestTemplateTable is the table that holds the request_template relation/edge.
+ RequestTemplateTable = "channel_monitors"
+ // RequestTemplateInverseTable is the table name for the ChannelMonitorRequestTemplate entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitorrequesttemplate" package.
+ RequestTemplateInverseTable = "channel_monitor_request_templates"
+ // RequestTemplateColumn is the table column denoting the request_template relation/edge.
+ RequestTemplateColumn = "template_id"
)
// Columns holds all SQL columns for channelmonitor fields.
@@ -79,6 +96,10 @@ var Columns = []string{
FieldIntervalSeconds,
FieldLastCheckedAt,
FieldCreatedBy,
+ FieldTemplateID,
+ FieldExtraHeaders,
+ FieldBodyOverrideMode,
+ FieldBodyOverride,
}
// ValidColumn reports if the column name is valid (part of the table columns).
@@ -116,6 +137,12 @@ var (
DefaultEnabled bool
// IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
IntervalSecondsValidator func(int) error
+ // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field.
+ DefaultExtraHeaders map[string]string
+ // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field.
+ DefaultBodyOverrideMode string
+ // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ BodyOverrideModeValidator func(string) error
)
// Provider defines the type for the "provider" enum field.
@@ -210,6 +237,16 @@ func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
}
+// ByTemplateID orders the results by the template_id field.
+func ByTemplateID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTemplateID, opts...).ToFunc()
+}
+
+// ByBodyOverrideMode orders the results by the body_override_mode field.
+func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc()
+}
+
// ByHistoryCount orders the results by history count.
func ByHistoryCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -237,6 +274,13 @@ func ByDailyRollups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
sqlgraph.OrderByNeighborTerms(s, newDailyRollupsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
+
+// ByRequestTemplateField orders the results by request_template field.
+func ByRequestTemplateField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newRequestTemplateStep(), sql.OrderByField(field, opts...))
+ }
+}
func newHistoryStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
@@ -251,3 +295,10 @@ func newDailyRollupsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
)
}
+func newRequestTemplateStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(RequestTemplateInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn),
+ )
+}
diff --git a/backend/ent/channelmonitor/where.go b/backend/ent/channelmonitor/where.go
index abb8484d..755d83a3 100644
--- a/backend/ent/channelmonitor/where.go
+++ b/backend/ent/channelmonitor/where.go
@@ -110,6 +110,16 @@ func CreatedBy(v int64) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v))
}
+// TemplateID applies equality check predicate on the "template_id" field. It's identical to TemplateIDEQ.
+func TemplateID(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v))
+}
+
+// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ.
+func BodyOverrideMode(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v))
@@ -685,6 +695,111 @@ func CreatedByLTE(v int64) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedBy, v))
}
+// TemplateIDEQ applies the EQ predicate on the "template_id" field.
+func TemplateIDEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v))
+}
+
+// TemplateIDNEQ applies the NEQ predicate on the "template_id" field.
+func TemplateIDNEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldTemplateID, v))
+}
+
+// TemplateIDIn applies the In predicate on the "template_id" field.
+func TemplateIDIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldTemplateID, vs...))
+}
+
+// TemplateIDNotIn applies the NotIn predicate on the "template_id" field.
+func TemplateIDNotIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldTemplateID, vs...))
+}
+
+// TemplateIDIsNil applies the IsNil predicate on the "template_id" field.
+func TemplateIDIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldTemplateID))
+}
+
+// TemplateIDNotNil applies the NotNil predicate on the "template_id" field.
+func TemplateIDNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldTemplateID))
+}
+
+// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field.
+func BodyOverrideModeEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field.
+func BodyOverrideModeNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field.
+func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field.
+func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field.
+func BodyOverrideModeGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field.
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field.
+func BodyOverrideModeLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field.
+func BodyOverrideModeLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field.
+func BodyOverrideModeContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field.
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field.
+func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field.
+func BodyOverrideIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldBodyOverride))
+}
+
+// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field.
+func BodyOverrideNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldBodyOverride))
+}
+
// HasHistory applies the HasEdge predicate on the "history" edge.
func HasHistory() predicate.ChannelMonitor {
return predicate.ChannelMonitor(func(s *sql.Selector) {
@@ -731,6 +846,29 @@ func HasDailyRollupsWith(preds ...predicate.ChannelMonitorDailyRollup) predicate
})
}
+// HasRequestTemplate applies the HasEdge predicate on the "request_template" edge.
+func HasRequestTemplate() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasRequestTemplateWith applies the HasEdge predicate on the "request_template" edge with a given conditions (other predicates).
+func HasRequestTemplateWith(preds ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newRequestTemplateStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.AndPredicates(predicates...))
diff --git a/backend/ent/channelmonitor_create.go b/backend/ent/channelmonitor_create.go
index 30a7b40d..2f70c300 100644
--- a/backend/ent/channelmonitor_create.go
+++ b/backend/ent/channelmonitor_create.go
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
)
// ChannelMonitorCreate is the builder for creating a ChannelMonitor entity.
@@ -142,6 +143,46 @@ func (_c *ChannelMonitorCreate) SetCreatedBy(v int64) *ChannelMonitorCreate {
return _c
}
+// SetTemplateID sets the "template_id" field.
+func (_c *ChannelMonitorCreate) SetTemplateID(v int64) *ChannelMonitorCreate {
+ _c.mutation.SetTemplateID(v)
+ return _c
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableTemplateID(v *int64) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetTemplateID(*v)
+ }
+ return _c
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_c *ChannelMonitorCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorCreate {
+ _c.mutation.SetExtraHeaders(v)
+ return _c
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_c *ChannelMonitorCreate) SetBodyOverrideMode(v string) *ChannelMonitorCreate {
+ _c.mutation.SetBodyOverrideMode(v)
+ return _c
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetBodyOverrideMode(*v)
+ }
+ return _c
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_c *ChannelMonitorCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorCreate {
+ _c.mutation.SetBodyOverride(v)
+ return _c
+}
+
// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
func (_c *ChannelMonitorCreate) AddHistoryIDs(ids ...int64) *ChannelMonitorCreate {
_c.mutation.AddHistoryIDs(ids...)
@@ -172,6 +213,25 @@ func (_c *ChannelMonitorCreate) AddDailyRollups(v ...*ChannelMonitorDailyRollup)
return _c.AddDailyRollupIDs(ids...)
}
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_c *ChannelMonitorCreate) SetRequestTemplateID(id int64) *ChannelMonitorCreate {
+ _c.mutation.SetRequestTemplateID(id)
+ return _c
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorCreate {
+ if id != nil {
+ _c = _c.SetRequestTemplateID(*id)
+ }
+ return _c
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_c *ChannelMonitorCreate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorCreate {
+ return _c.SetRequestTemplateID(v.ID)
+}
+
// Mutation returns the ChannelMonitorMutation object of the builder.
func (_c *ChannelMonitorCreate) Mutation() *ChannelMonitorMutation {
return _c.mutation
@@ -227,6 +287,14 @@ func (_c *ChannelMonitorCreate) defaults() {
v := channelmonitor.DefaultEnabled
_c.mutation.SetEnabled(v)
}
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ v := channelmonitor.DefaultExtraHeaders
+ _c.mutation.SetExtraHeaders(v)
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ v := channelmonitor.DefaultBodyOverrideMode
+ _c.mutation.SetBodyOverrideMode(v)
+ }
}
// check runs all checks and user-defined validators on the builder.
@@ -299,6 +367,17 @@ func (_c *ChannelMonitorCreate) check() error {
if _, ok := _c.mutation.CreatedBy(); !ok {
return &ValidationError{Name: "created_by", err: errors.New(`ent: missing required field "ChannelMonitor.created_by"`)}
}
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitor.extra_headers"`)}
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitor.body_override_mode"`)}
+ }
+ if v, ok := _c.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
return nil
}
@@ -378,6 +457,18 @@ func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateS
_spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
_node.CreatedBy = value
}
+ if value, ok := _c.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ _node.ExtraHeaders = value
+ }
+ if value, ok := _c.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ _node.BodyOverrideMode = value
+ }
+ if value, ok := _c.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ _node.BodyOverride = value
+ }
if nodes := _c.mutation.HistoryIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -410,6 +501,23 @@ func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateS
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TemplateID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
@@ -630,6 +738,66 @@ func (u *ChannelMonitorUpsert) AddCreatedBy(v int64) *ChannelMonitorUpsert {
return u
}
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsert) SetTemplateID(v int64) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldTemplateID, v)
+ return u
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateTemplateID() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldTemplateID)
+ return u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsert) ClearTemplateID() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldTemplateID)
+ return u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldExtraHeaders, v)
+ return u
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateExtraHeaders() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldExtraHeaders)
+ return u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsert) SetBodyOverrideMode(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldBodyOverrideMode, v)
+ return u
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateBodyOverrideMode() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldBodyOverrideMode)
+ return u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldBodyOverride, v)
+ return u
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateBodyOverride() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldBodyOverride)
+ return u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsert) ClearBodyOverride() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldBodyOverride)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -871,6 +1039,76 @@ func (u *ChannelMonitorUpsertOne) UpdateCreatedBy() *ChannelMonitorUpsertOne {
})
}
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsertOne) SetTemplateID(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetTemplateID(v)
+ })
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateTemplateID() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateTemplateID()
+ })
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsertOne) ClearTemplateID() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearTemplateID()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateExtraHeaders() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateBodyOverride() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsertOne) ClearBodyOverride() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
// Exec executes the query.
func (u *ChannelMonitorUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1278,6 +1516,76 @@ func (u *ChannelMonitorUpsertBulk) UpdateCreatedBy() *ChannelMonitorUpsertBulk {
})
}
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsertBulk) SetTemplateID(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetTemplateID(v)
+ })
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateTemplateID() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateTemplateID()
+ })
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsertBulk) ClearTemplateID() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearTemplateID()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateExtraHeaders() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverride() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsertBulk) ClearBodyOverride() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
// Exec executes the query.
func (u *ChannelMonitorUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/channelmonitor_query.go b/backend/ent/channelmonitor_query.go
index 2ebd95bb..b6722e78 100644
--- a/backend/ent/channelmonitor_query.go
+++ b/backend/ent/channelmonitor_query.go
@@ -16,19 +16,21 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ChannelMonitorQuery is the builder for querying ChannelMonitor entities.
type ChannelMonitorQuery struct {
config
- ctx *QueryContext
- order []channelmonitor.OrderOption
- inters []Interceptor
- predicates []predicate.ChannelMonitor
- withHistory *ChannelMonitorHistoryQuery
- withDailyRollups *ChannelMonitorDailyRollupQuery
- modifiers []func(*sql.Selector)
+ ctx *QueryContext
+ order []channelmonitor.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitor
+ withHistory *ChannelMonitorHistoryQuery
+ withDailyRollups *ChannelMonitorDailyRollupQuery
+ withRequestTemplate *ChannelMonitorRequestTemplateQuery
+ modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@@ -109,6 +111,28 @@ func (_q *ChannelMonitorQuery) QueryDailyRollups() *ChannelMonitorDailyRollupQue
return query
}
+// QueryRequestTemplate chains the current query on the "request_template" edge.
+func (_q *ChannelMonitorQuery) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// First returns the first ChannelMonitor entity from the query.
// Returns a *NotFoundError when no ChannelMonitor was found.
func (_q *ChannelMonitorQuery) First(ctx context.Context) (*ChannelMonitor, error) {
@@ -296,13 +320,14 @@ func (_q *ChannelMonitorQuery) Clone() *ChannelMonitorQuery {
return nil
}
return &ChannelMonitorQuery{
- config: _q.config,
- ctx: _q.ctx.Clone(),
- order: append([]channelmonitor.OrderOption{}, _q.order...),
- inters: append([]Interceptor{}, _q.inters...),
- predicates: append([]predicate.ChannelMonitor{}, _q.predicates...),
- withHistory: _q.withHistory.Clone(),
- withDailyRollups: _q.withDailyRollups.Clone(),
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitor.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitor{}, _q.predicates...),
+ withHistory: _q.withHistory.Clone(),
+ withDailyRollups: _q.withDailyRollups.Clone(),
+ withRequestTemplate: _q.withRequestTemplate.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
@@ -331,6 +356,17 @@ func (_q *ChannelMonitorQuery) WithDailyRollups(opts ...func(*ChannelMonitorDail
return _q
}
+// WithRequestTemplate tells the query-builder to eager-load the nodes that are connected to
+// the "request_template" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithRequestTemplate(opts ...func(*ChannelMonitorRequestTemplateQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withRequestTemplate = query
+ return _q
+}
+
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
@@ -409,9 +445,10 @@ func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) (
var (
nodes = []*ChannelMonitor{}
_spec = _q.querySpec()
- loadedTypes = [2]bool{
+ loadedTypes = [3]bool{
_q.withHistory != nil,
_q.withDailyRollups != nil,
+ _q.withRequestTemplate != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
@@ -451,6 +488,12 @@ func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) (
return nil, err
}
}
+ if query := _q.withRequestTemplate; query != nil {
+ if err := _q.loadRequestTemplate(ctx, query, nodes, nil,
+ func(n *ChannelMonitor, e *ChannelMonitorRequestTemplate) { n.Edges.RequestTemplate = e }); err != nil {
+ return nil, err
+ }
+ }
return nodes, nil
}
@@ -514,6 +557,38 @@ func (_q *ChannelMonitorQuery) loadDailyRollups(ctx context.Context, query *Chan
}
return nil
}
+func (_q *ChannelMonitorQuery) loadRequestTemplate(ctx context.Context, query *ChannelMonitorRequestTemplateQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorRequestTemplate)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitor)
+ for i := range nodes {
+ if nodes[i].TemplateID == nil {
+ continue
+ }
+ fk := *nodes[i].TemplateID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitorrequesttemplate.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "template_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
func (_q *ChannelMonitorQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
@@ -543,6 +618,9 @@ func (_q *ChannelMonitorQuery) querySpec() *sqlgraph.QuerySpec {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
+ if _q.withRequestTemplate != nil {
+ _spec.Node.AddColumnOnce(channelmonitor.FieldTemplateID)
+ }
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
diff --git a/backend/ent/channelmonitor_update.go b/backend/ent/channelmonitor_update.go
index 7ba4e449..4bbcd564 100644
--- a/backend/ent/channelmonitor_update.go
+++ b/backend/ent/channelmonitor_update.go
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
@@ -215,6 +216,58 @@ func (_u *ChannelMonitorUpdate) AddCreatedBy(v int64) *ChannelMonitorUpdate {
return _u
}
+// SetTemplateID sets the "template_id" field.
+func (_u *ChannelMonitorUpdate) SetTemplateID(v int64) *ChannelMonitorUpdate {
+ _u.mutation.SetTemplateID(v)
+ return _u
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableTemplateID(v *int64) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetTemplateID(*v)
+ }
+ return _u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (_u *ChannelMonitorUpdate) ClearTemplateID() *ChannelMonitorUpdate {
+ _u.mutation.ClearTemplateID()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdate {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorUpdate) SetBodyOverrideMode(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdate {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorUpdate) ClearBodyOverride() *ChannelMonitorUpdate {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
func (_u *ChannelMonitorUpdate) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdate {
_u.mutation.AddHistoryIDs(ids...)
@@ -245,6 +298,25 @@ func (_u *ChannelMonitorUpdate) AddDailyRollups(v ...*ChannelMonitorDailyRollup)
return _u.AddDailyRollupIDs(ids...)
}
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_u *ChannelMonitorUpdate) SetRequestTemplateID(id int64) *ChannelMonitorUpdate {
+ _u.mutation.SetRequestTemplateID(id)
+ return _u
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdate {
+ if id != nil {
+ _u = _u.SetRequestTemplateID(*id)
+ }
+ return _u
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdate {
+ return _u.SetRequestTemplateID(v.ID)
+}
+
// Mutation returns the ChannelMonitorMutation object of the builder.
func (_u *ChannelMonitorUpdate) Mutation() *ChannelMonitorMutation {
return _u.mutation
@@ -292,6 +364,12 @@ func (_u *ChannelMonitorUpdate) RemoveDailyRollups(v ...*ChannelMonitorDailyRoll
return _u.RemoveDailyRollupIDs(ids...)
}
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdate) ClearRequestTemplate() *ChannelMonitorUpdate {
+ _u.mutation.ClearRequestTemplate()
+ return _u
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *ChannelMonitorUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
@@ -365,6 +443,11 @@ func (_u *ChannelMonitorUpdate) check() error {
return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
}
}
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
return nil
}
@@ -433,6 +516,18 @@ func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err err
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
}
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON)
+ }
if _u.mutation.HistoryCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -523,6 +618,35 @@ func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err err
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.RequestTemplateCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{channelmonitor.Label}
@@ -727,6 +851,58 @@ func (_u *ChannelMonitorUpdateOne) AddCreatedBy(v int64) *ChannelMonitorUpdateOn
return _u
}
+// SetTemplateID sets the "template_id" field.
+func (_u *ChannelMonitorUpdateOne) SetTemplateID(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.SetTemplateID(v)
+ return _u
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableTemplateID(v *int64) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetTemplateID(*v)
+ }
+ return _u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (_u *ChannelMonitorUpdateOne) ClearTemplateID() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearTemplateID()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdateOne {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorUpdateOne) ClearBodyOverride() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
func (_u *ChannelMonitorUpdateOne) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne {
_u.mutation.AddHistoryIDs(ids...)
@@ -757,6 +933,25 @@ func (_u *ChannelMonitorUpdateOne) AddDailyRollups(v ...*ChannelMonitorDailyRoll
return _u.AddDailyRollupIDs(ids...)
}
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_u *ChannelMonitorUpdateOne) SetRequestTemplateID(id int64) *ChannelMonitorUpdateOne {
+ _u.mutation.SetRequestTemplateID(id)
+ return _u
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdateOne {
+ if id != nil {
+ _u = _u.SetRequestTemplateID(*id)
+ }
+ return _u
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdateOne) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdateOne {
+ return _u.SetRequestTemplateID(v.ID)
+}
+
// Mutation returns the ChannelMonitorMutation object of the builder.
func (_u *ChannelMonitorUpdateOne) Mutation() *ChannelMonitorMutation {
return _u.mutation
@@ -804,6 +999,12 @@ func (_u *ChannelMonitorUpdateOne) RemoveDailyRollups(v ...*ChannelMonitorDailyR
return _u.RemoveDailyRollupIDs(ids...)
}
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdateOne) ClearRequestTemplate() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearRequestTemplate()
+ return _u
+}
+
// Where appends a list predicates to the ChannelMonitorUpdate builder.
func (_u *ChannelMonitorUpdateOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdateOne {
_u.mutation.Where(ps...)
@@ -890,6 +1091,11 @@ func (_u *ChannelMonitorUpdateOne) check() error {
return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
}
}
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
return nil
}
@@ -975,6 +1181,18 @@ func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelM
if value, ok := _u.mutation.AddedCreatedBy(); ok {
_spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
}
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON)
+ }
if _u.mutation.HistoryCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1065,6 +1283,35 @@ func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelM
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.RequestTemplateCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &ChannelMonitor{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/ent/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate.go
new file mode 100644
index 00000000..b8429a4d
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate.go
@@ -0,0 +1,216 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorRequestTemplate is the model entity for the ChannelMonitorRequestTemplate schema.
+type ChannelMonitorRequestTemplate struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Provider holds the value of the "provider" field.
+ Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"`
+ // Description holds the value of the "description" field.
+ Description string `json:"description,omitempty"`
+ // ExtraHeaders holds the value of the "extra_headers" field.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // BodyOverrideMode holds the value of the "body_override_mode" field.
+ BodyOverrideMode string `json:"body_override_mode,omitempty"`
+ // BodyOverride holds the value of the "body_override" field.
+ BodyOverride map[string]interface{} `json:"body_override,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorRequestTemplateQuery when eager-loading is set.
+ Edges ChannelMonitorRequestTemplateEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorRequestTemplateEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorRequestTemplateEdges struct {
+ // Monitors holds the value of the monitors edge.
+ Monitors []*ChannelMonitor `json:"monitors,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorsOrErr returns the Monitors value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorRequestTemplateEdges) MonitorsOrErr() ([]*ChannelMonitor, error) {
+ if e.loadedTypes[0] {
+ return e.Monitors, nil
+ }
+ return nil, &NotLoadedError{edge: "monitors"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorrequesttemplate.FieldExtraHeaders, channelmonitorrequesttemplate.FieldBodyOverride:
+ values[i] = new([]byte)
+ case channelmonitorrequesttemplate.FieldID:
+ values[i] = new(sql.NullInt64)
+ case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ values[i] = new(sql.NullString)
+ case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorRequestTemplate fields.
+func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorrequesttemplate.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case channelmonitorrequesttemplate.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case channelmonitorrequesttemplate.FieldProvider:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider", values[i])
+ } else if value.Valid {
+ _m.Provider = channelmonitorrequesttemplate.Provider(value.String)
+ }
+ case channelmonitorrequesttemplate.FieldDescription:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field description", values[i])
+ } else if value.Valid {
+ _m.Description = value.String
+ }
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_headers", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil {
+ return fmt.Errorf("unmarshal field extra_headers: %w", err)
+ }
+ }
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override_mode", values[i])
+ } else if value.Valid {
+ _m.BodyOverrideMode = value.String
+ }
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil {
+ return fmt.Errorf("unmarshal field body_override: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorRequestTemplate.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorRequestTemplate) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitors queries the "monitors" edge of the ChannelMonitorRequestTemplate entity.
+func (_m *ChannelMonitorRequestTemplate) QueryMonitors() *ChannelMonitorQuery {
+ return NewChannelMonitorRequestTemplateClient(_m.config).QueryMonitors(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorRequestTemplate.
+// Note that you need to call ChannelMonitorRequestTemplate.Unwrap() before calling this method if this ChannelMonitorRequestTemplate
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorRequestTemplate) Update() *ChannelMonitorRequestTemplateUpdateOne {
+ return NewChannelMonitorRequestTemplateClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorRequestTemplate entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorRequestTemplate) Unwrap() *ChannelMonitorRequestTemplate {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorRequestTemplate is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorRequestTemplate) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorRequestTemplate(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("provider=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Provider))
+ builder.WriteString(", ")
+ builder.WriteString("description=")
+ builder.WriteString(_m.Description)
+ builder.WriteString(", ")
+ builder.WriteString("extra_headers=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders))
+ builder.WriteString(", ")
+ builder.WriteString("body_override_mode=")
+ builder.WriteString(_m.BodyOverrideMode)
+ builder.WriteString(", ")
+ builder.WriteString("body_override=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorRequestTemplates is a parsable slice of ChannelMonitorRequestTemplate.
+type ChannelMonitorRequestTemplates []*ChannelMonitorRequestTemplate
diff --git a/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
new file mode 100644
index 00000000..65b8d641
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
@@ -0,0 +1,172 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorrequesttemplate
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitorrequesttemplate type in the database.
+ Label = "channel_monitor_request_template"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldProvider holds the string denoting the provider field in the database.
+ FieldProvider = "provider"
+ // FieldDescription holds the string denoting the description field in the database.
+ FieldDescription = "description"
+ // FieldExtraHeaders holds the string denoting the extra_headers field in the database.
+ FieldExtraHeaders = "extra_headers"
+ // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database.
+ FieldBodyOverrideMode = "body_override_mode"
+ // FieldBodyOverride holds the string denoting the body_override field in the database.
+ FieldBodyOverride = "body_override"
+ // EdgeMonitors holds the string denoting the monitors edge name in mutations.
+ EdgeMonitors = "monitors"
+ // Table holds the table name of the channelmonitorrequesttemplate in the database.
+ Table = "channel_monitor_request_templates"
+ // MonitorsTable is the table that holds the monitors relation/edge.
+ MonitorsTable = "channel_monitors"
+ // MonitorsInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorsInverseTable = "channel_monitors"
+ // MonitorsColumn is the table column denoting the monitors relation/edge.
+ MonitorsColumn = "template_id"
+)
+
+// Columns holds all SQL columns for channelmonitorrequesttemplate fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldName,
+ FieldProvider,
+ FieldDescription,
+ FieldExtraHeaders,
+ FieldBodyOverrideMode,
+ FieldBodyOverride,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // DefaultDescription holds the default value on creation for the "description" field.
+ DefaultDescription string
+ // DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
+ DescriptionValidator func(string) error
+ // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field.
+ DefaultExtraHeaders map[string]string
+ // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field.
+ DefaultBodyOverrideMode string
+ // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ BodyOverrideModeValidator func(string) error
+)
+
+// Provider defines the type for the "provider" enum field.
+type Provider string
+
+// Provider values.
+const (
+ ProviderOpenai Provider = "openai"
+ ProviderAnthropic Provider = "anthropic"
+ ProviderGemini Provider = "gemini"
+)
+
+func (pr Provider) String() string {
+ return string(pr)
+}
+
+// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save.
+func ProviderValidator(pr Provider) error {
+ switch pr {
+ case ProviderOpenai, ProviderAnthropic, ProviderGemini:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitorrequesttemplate: invalid enum value for provider field: %q", pr)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitorRequestTemplate queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByProvider orders the results by the provider field.
+func ByProvider(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProvider, opts...).ToFunc()
+}
+
+// ByDescription orders the results by the description field.
+func ByDescription(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDescription, opts...).ToFunc()
+}
+
+// ByBodyOverrideMode orders the results by the body_override_mode field.
+func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc()
+}
+
+// ByMonitorsCount orders the results by monitors count.
+func ByMonitorsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newMonitorsStep(), opts...)
+ }
+}
+
+// ByMonitors orders the results by monitors terms.
+func ByMonitors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newMonitorsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn),
+ )
+}
diff --git a/backend/ent/channelmonitorrequesttemplate/where.go b/backend/ent/channelmonitorrequesttemplate/where.go
new file mode 100644
index 00000000..b95e5df0
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate/where.go
@@ -0,0 +1,434 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorrequesttemplate
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
+}
+
+// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
+func Description(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
+}
+
+// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ.
+func BodyOverrideMode(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldName, v))
+}
+
+// ProviderEQ applies the EQ predicate on the "provider" field.
+func ProviderEQ(v Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldProvider, v))
+}
+
+// ProviderNEQ applies the NEQ predicate on the "provider" field.
+func ProviderNEQ(v Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldProvider, v))
+}
+
+// ProviderIn applies the In predicate on the "provider" field.
+func ProviderIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldProvider, vs...))
+}
+
+// ProviderNotIn applies the NotIn predicate on the "provider" field.
+func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...))
+}
+
+// DescriptionEQ applies the EQ predicate on the "description" field.
+func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
+}
+
+// DescriptionNEQ applies the NEQ predicate on the "description" field.
+func DescriptionNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldDescription, v))
+}
+
+// DescriptionIn applies the In predicate on the "description" field.
+func DescriptionIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldDescription, vs...))
+}
+
+// DescriptionNotIn applies the NotIn predicate on the "description" field.
+func DescriptionNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldDescription, vs...))
+}
+
+// DescriptionGT applies the GT predicate on the "description" field.
+func DescriptionGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldDescription, v))
+}
+
+// DescriptionGTE applies the GTE predicate on the "description" field.
+func DescriptionGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldDescription, v))
+}
+
+// DescriptionLT applies the LT predicate on the "description" field.
+func DescriptionLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldDescription, v))
+}
+
+// DescriptionLTE applies the LTE predicate on the "description" field.
+func DescriptionLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldDescription, v))
+}
+
+// DescriptionContains applies the Contains predicate on the "description" field.
+func DescriptionContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldDescription, v))
+}
+
+// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
+func DescriptionHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldDescription, v))
+}
+
+// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
+func DescriptionHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldDescription, v))
+}
+
+// DescriptionIsNil applies the IsNil predicate on the "description" field.
+func DescriptionIsNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldDescription))
+}
+
+// DescriptionNotNil applies the NotNil predicate on the "description" field.
+func DescriptionNotNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldDescription))
+}
+
+// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
+func DescriptionEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldDescription, v))
+}
+
+// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
+func DescriptionContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldDescription, v))
+}
+
+// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field.
+func BodyOverrideModeEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field.
+func BodyOverrideModeNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field.
+func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field.
+func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field.
+func BodyOverrideModeGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field.
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field.
+func BodyOverrideModeLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field.
+func BodyOverrideModeLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field.
+func BodyOverrideModeContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field.
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field.
+func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field.
+func BodyOverrideIsNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldBodyOverride))
+}
+
+// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field.
+func BodyOverrideNotNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldBodyOverride))
+}
+
+// HasMonitors applies the HasEdge predicate on the "monitors" edge.
+func HasMonitors() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorsWith applies the HasEdge predicate on the "monitors" edge with a given conditions (other predicates).
+func HasMonitorsWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) {
+ step := newMonitorsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_create.go b/backend/ent/channelmonitorrequesttemplate_create.go
new file mode 100644
index 00000000..1ba842cd
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_create.go
@@ -0,0 +1,942 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorRequestTemplateCreate is the builder for creating a ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateCreate struct {
+ config
+ mutation *ChannelMonitorRequestTemplateMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetCreatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetName(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetProvider sets the "provider" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetProvider(v)
+ return _c
+}
+
+// SetDescription sets the "description" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetDescription(v)
+ return _c
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetDescription(*v)
+ }
+ return _c
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetExtraHeaders(v)
+ return _c
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetBodyOverrideMode(v)
+ return _c
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetBodyOverrideMode(*v)
+ }
+ return _c
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetBodyOverride(v)
+ return _c
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_c *ChannelMonitorRequestTemplateCreate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.AddMonitorIDs(ids...)
+ return _c
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_c *ChannelMonitorRequestTemplateCreate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_c *ChannelMonitorRequestTemplateCreate) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorRequestTemplate in the database.
+func (_c *ChannelMonitorRequestTemplateCreate) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorRequestTemplateCreate) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorRequestTemplateCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorRequestTemplateCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := channelmonitorrequesttemplate.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Description(); !ok {
+ v := channelmonitorrequesttemplate.DefaultDescription
+ _c.mutation.SetDescription(v)
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ v := channelmonitorrequesttemplate.DefaultExtraHeaders
+ _c.mutation.SetExtraHeaders(v)
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ v := channelmonitorrequesttemplate.DefaultBodyOverrideMode
+ _c.mutation.SetBodyOverrideMode(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorRequestTemplateCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Provider(); !ok {
+ return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.provider"`)}
+ }
+ if v, ok := _c.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.extra_headers"`)}
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.body_override_mode"`)}
+ }
+ if v, ok := _c.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorRequestTemplateCreate) sqlSave(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequestTemplate, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorRequestTemplate{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ _node.Provider = value
+ }
+ if value, ok := _c.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ _node.Description = value
+ }
+ if value, ok := _c.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ _node.ExtraHeaders = value
+ }
+ if value, ok := _c.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ _node.BodyOverrideMode = value
+ }
+ if value, ok := _c.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ _node.BodyOverride = value
+ }
+ if nodes := _c.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorRequestTemplateUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreate) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorRequestTemplateUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorRequestTemplateUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorRequestTemplate node.
+ ChannelMonitorRequestTemplateUpsertOne struct {
+ create *ChannelMonitorRequestTemplateCreate
+ }
+
+ // ChannelMonitorRequestTemplateUpsert is the "OnConflict" setter.
+ ChannelMonitorRequestTemplateUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldUpdatedAt)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetName(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateName() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldName)
+ return u
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldProvider, v)
+ return u
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldProvider)
+ return u
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldDescription, v)
+ return u
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateDescription() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldDescription)
+ return u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsert) ClearDescription() *ChannelMonitorRequestTemplateUpsert {
+ u.SetNull(channelmonitorrequesttemplate.FieldDescription)
+ return u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldExtraHeaders, v)
+ return u
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldExtraHeaders)
+ return u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldBodyOverrideMode, v)
+ return u
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverrideMode)
+ return u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldBodyOverride, v)
+ return u
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverride)
+ return u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsert) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsert {
+ u.SetNull(channelmonitorrequesttemplate.FieldBodyOverride)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertOne) Ignore() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorRequestTemplateUpsertOne) DoNothing() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorRequestTemplateUpsertOne) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetName(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateName() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateDescription() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearDescription() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearDescription()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorRequestTemplateUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorRequestTemplateCreateBulk is the builder for creating many ChannelMonitorRequestTemplate entities in bulk.
+type ChannelMonitorRequestTemplateCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorRequestTemplateCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorRequestTemplate entities in the database.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorRequestTemplate, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorRequestTemplateMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorRequestTemplate {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorRequestTemplate.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+//	    // that were proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorRequestTemplateUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorRequestTemplateUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorRequestTemplateUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorRequestTemplate nodes.
+type ChannelMonitorRequestTemplateUpsertBulk struct {
+ create *ChannelMonitorRequestTemplateCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Ignore() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) DoNothing() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetName(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateName() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateDescription() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearDescription() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearDescription()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorRequestTemplateCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_delete.go b/backend/ent/channelmonitorrequesttemplate_delete.go
new file mode 100644
index 00000000..98d365c8
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateDelete is the builder for deleting a ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateDelete builder.
+func (_d *ChannelMonitorRequestTemplateDelete) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorRequestTemplateDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorRequestTemplateDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorRequestTemplateDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorRequestTemplateDeleteOne is the builder for deleting a single ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateDeleteOne struct {
+ _d *ChannelMonitorRequestTemplateDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateDelete builder.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_query.go b/backend/ent/channelmonitorrequesttemplate_query.go
new file mode 100644
index 00000000..6491ea60
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_query.go
@@ -0,0 +1,648 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateQuery is the builder for querying ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitorrequesttemplate.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorRequestTemplate
+ withMonitors *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorRequestTemplateQuery builder.
+func (_q *ChannelMonitorRequestTemplateQuery) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorRequestTemplateQuery) Limit(limit int) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorRequestTemplateQuery) Offset(offset int) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorRequestTemplateQuery) Unique(unique bool) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorRequestTemplateQuery) Order(o ...channelmonitorrequesttemplate.OrderOption) *ChannelMonitorRequestTemplateQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitors chains the current query on the "monitors" edge.
+func (_q *ChannelMonitorRequestTemplateQuery) QueryMonitors() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorRequestTemplate entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate was found.
+func (_q *ChannelMonitorRequestTemplateQuery) First(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitorrequesttemplate.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorRequestTemplate ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate ID was found.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorRequestTemplate entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate entity is found.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate entities are found.
+func (_q *ChannelMonitorRequestTemplateQuery) Only(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ return nil, &NotSingularError{channelmonitorrequesttemplate.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorRequestTemplate ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ err = &NotSingularError{channelmonitorrequesttemplate.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorRequestTemplates.
+func (_q *ChannelMonitorRequestTemplateQuery) All(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorRequestTemplate, *ChannelMonitorRequestTemplateQuery]()
+ return withInterceptors[[]*ChannelMonitorRequestTemplate](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) AllX(ctx context.Context) []*ChannelMonitorRequestTemplate {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorRequestTemplate IDs.
+func (_q *ChannelMonitorRequestTemplateQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitorrequesttemplate.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorRequestTemplateQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorRequestTemplateQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorRequestTemplateQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorRequestTemplateQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorRequestTemplateQuery) Clone() *ChannelMonitorRequestTemplateQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorRequestTemplateQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitorrequesttemplate.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorRequestTemplate{}, _q.predicates...),
+ withMonitors: _q.withMonitors.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitors tells the query-builder to eager-load the nodes that are connected to
+// the "monitors" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorRequestTemplateQuery) WithMonitors(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitors = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorRequestTemplate.Query().
+// GroupBy(channelmonitorrequesttemplate.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorRequestTemplateQuery) GroupBy(field string, fields ...string) *ChannelMonitorRequestTemplateGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorRequestTemplateGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitorrequesttemplate.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.ChannelMonitorRequestTemplate.Query().
+// Select(channelmonitorrequesttemplate.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorRequestTemplateQuery) Select(fields ...string) *ChannelMonitorRequestTemplateSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorRequestTemplateSelect{ChannelMonitorRequestTemplateQuery: _q}
+ sbuild.label = channelmonitorrequesttemplate.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorRequestTemplateSelect configured with the given aggregations.
+func (_q *ChannelMonitorRequestTemplateQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitorrequesttemplate.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorRequestTemplate, error) {
+ var (
+ nodes = []*ChannelMonitorRequestTemplate{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitors != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorRequestTemplate).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorRequestTemplate{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitors; query != nil {
+ if err := _q.loadMonitors(ctx, query, nodes,
+ func(n *ChannelMonitorRequestTemplate) { n.Edges.Monitors = []*ChannelMonitor{} },
+ func(n *ChannelMonitorRequestTemplate, e *ChannelMonitor) {
+ n.Edges.Monitors = append(n.Edges.Monitors, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) loadMonitors(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorRequestTemplate, init func(*ChannelMonitorRequestTemplate), assign func(*ChannelMonitorRequestTemplate, *ChannelMonitor)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitorRequestTemplate)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitor.FieldTemplateID)
+ }
+ query.Where(predicate.ChannelMonitor(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitorrequesttemplate.MonitorsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TemplateID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "template_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "template_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitorrequesttemplate.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitorrequesttemplate.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitorrequesttemplate.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorRequestTemplateQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorRequestTemplateQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorRequestTemplateGroupBy is the group-by builder for ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateGroupBy struct {
+ selector
+ build *ChannelMonitorRequestTemplateQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorRequestTemplateGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorRequestTemplateGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorRequestTemplateGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorRequestTemplateSelect is the builder for selecting fields of ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateSelect struct {
+ *ChannelMonitorRequestTemplateQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorRequestTemplateSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorRequestTemplateSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateSelect](ctx, _s.ChannelMonitorRequestTemplateQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorRequestTemplateSelect) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_update.go b/backend/ent/channelmonitorrequesttemplate_update.go
new file mode 100644
index 00000000..8f55ba04
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_update.go
@@ -0,0 +1,639 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateUpdate is the builder for updating ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// Where appends a list of predicates to the ChannelMonitorRequestTemplateUpdate builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetName(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearDescription() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearDescription()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.AddMonitorIDs(ids...)
+ return _u
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _u.mutation
+}
+
+// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearMonitors() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearMonitors()
+ return _u
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.RemoveMonitorIDs(ids...)
+ return _u
+}
+
+// RemoveMonitors removes "monitors" edges to ChannelMonitor entities.
+func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveMonitorIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorRequestTemplateUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorRequestTemplateUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorRequestTemplateUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ }
+ if _u.mutation.DescriptionCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorRequestTemplateUpdateOne is the builder for updating a single ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetName(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearDescription() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearDescription()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.AddMonitorIDs(ids...)
+ return _u
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _u.mutation
+}
+
+// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearMonitors() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearMonitors()
+ return _u
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.RemoveMonitorIDs(ids...)
+ return _u
+}
+
+// RemoveMonitors removes "monitors" edges to ChannelMonitor entities.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveMonitorIDs(ids...)
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Select(field string, fields ...string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorRequestTemplate, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorRequestTemplate.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID)
+ for _, f := range fields {
+ if !channelmonitorrequesttemplate.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitorrequesttemplate.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ }
+ if _u.mutation.DescriptionCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorRequestTemplate{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index ebc7fc5e..df20ddfa 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -25,6 +25,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
@@ -77,6 +78,8 @@ type Client struct {
ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
// ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
ChannelMonitorHistory *ChannelMonitorHistoryClient
+ // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
+ ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
@@ -144,6 +147,7 @@ func (c *Client) init() {
c.ChannelMonitor = NewChannelMonitorClient(c.config)
c.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(c.config)
c.ChannelMonitorHistory = NewChannelMonitorHistoryClient(c.config)
+ c.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
@@ -257,41 +261,42 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- AuthIdentity: NewAuthIdentityClient(cfg),
- AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
- ChannelMonitor: NewChannelMonitorClient(cfg),
- ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
- ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PendingAuthSession: NewPendingAuthSessionClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -309,41 +314,42 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- AuthIdentity: NewAuthIdentityClient(cfg),
- AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
- ChannelMonitor: NewChannelMonitorClient(cfg),
- ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
- ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PendingAuthSession: NewPendingAuthSessionClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -375,8 +381,9 @@ func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
- c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory, c.ErrorPassthroughRule,
- c.Group, c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory,
+ c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
@@ -393,8 +400,9 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
- c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory, c.ErrorPassthroughRule,
- c.Group, c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory,
+ c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
@@ -428,6 +436,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.ChannelMonitorDailyRollup.mutate(ctx, m)
case *ChannelMonitorHistoryMutation:
return c.ChannelMonitorHistory.mutate(ctx, m)
+ case *ChannelMonitorRequestTemplateMutation:
+ return c.ChannelMonitorRequestTemplate.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
@@ -1761,6 +1771,22 @@ func (c *ChannelMonitorClient) QueryDailyRollups(_m *ChannelMonitor) *ChannelMon
return query
}
+// QueryRequestTemplate queries the request_template edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryRequestTemplate(_m *ChannelMonitor) *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// Hooks returns the client hooks.
func (c *ChannelMonitorClient) Hooks() []Hook {
return c.hooks.ChannelMonitor
@@ -2084,6 +2110,155 @@ func (c *ChannelMonitorHistoryClient) mutate(ctx context.Context, m *ChannelMoni
}
}
+// ChannelMonitorRequestTemplateClient is a client for the ChannelMonitorRequestTemplate schema.
+type ChannelMonitorRequestTemplateClient struct {
+ config
+}
+
+// NewChannelMonitorRequestTemplateClient returns a client for the ChannelMonitorRequestTemplate from the given config.
+func NewChannelMonitorRequestTemplateClient(c config) *ChannelMonitorRequestTemplateClient {
+ return &ChannelMonitorRequestTemplateClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitorrequesttemplate.Hooks(f(g(h())))`.
+func (c *ChannelMonitorRequestTemplateClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorRequestTemplate = append(c.hooks.ChannelMonitorRequestTemplate, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitorrequesttemplate.Intercept(f(g(h())))`.
+func (c *ChannelMonitorRequestTemplateClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorRequestTemplate = append(c.inters.ChannelMonitorRequestTemplate, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorRequestTemplate entity.
+func (c *ChannelMonitorRequestTemplateClient) Create() *ChannelMonitorRequestTemplateCreate {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpCreate)
+ return &ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorRequestTemplate entities.
+func (c *ChannelMonitorRequestTemplateClient) CreateBulk(builders ...*ChannelMonitorRequestTemplateCreate) *ChannelMonitorRequestTemplateCreateBulk {
+ return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorRequestTemplateClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorRequestTemplateCreate, int)) *ChannelMonitorRequestTemplateCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorRequestTemplateCreateBulk{err: fmt.Errorf("calling to ChannelMonitorRequestTemplateClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorRequestTemplateCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Update() *ChannelMonitorRequestTemplateUpdate {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdate)
+ return &ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorRequestTemplateClient) UpdateOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplate(_m))
+ return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorRequestTemplateClient) UpdateOneID(id int64) *ChannelMonitorRequestTemplateUpdateOne {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplateID(id))
+ return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Delete() *ChannelMonitorRequestTemplateDelete {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpDelete)
+ return &ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorRequestTemplateClient) DeleteOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorRequestTemplateClient) DeleteOneID(id int64) *ChannelMonitorRequestTemplateDeleteOne {
+ builder := c.Delete().Where(channelmonitorrequesttemplate.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorRequestTemplateDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Query() *ChannelMonitorRequestTemplateQuery {
+ return &ChannelMonitorRequestTemplateQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorRequestTemplate},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorRequestTemplate entity by its id.
+func (c *ChannelMonitorRequestTemplateClient) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) {
+ return c.Query().Where(channelmonitorrequesttemplate.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorRequestTemplateClient) GetX(ctx context.Context, id int64) *ChannelMonitorRequestTemplate {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitors queries the monitors edge of a ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) QueryMonitors(_m *ChannelMonitorRequestTemplate) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorRequestTemplateClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorRequestTemplate
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorRequestTemplateClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorRequestTemplate
+}
+
+func (c *ChannelMonitorRequestTemplateClient) mutate(ctx context.Context, m *ChannelMonitorRequestTemplateMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorRequestTemplate mutation op: %q", m.Op())
+ }
+}
+
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@@ -5845,22 +6020,22 @@ type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
- ChannelMonitorHistory, ErrorPassthroughRule, Group, IdempotencyRecord,
- IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
- RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Hook
+ ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule,
+ Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
- ChannelMonitorHistory, ErrorPassthroughRule, Group, IdempotencyRecord,
- IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy,
- RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Interceptor
+ ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule,
+ Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 71d17624..c9fcc314 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
@@ -105,39 +106,40 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- authidentity.Table: authidentity.ValidColumn,
- authidentitychannel.Table: authidentitychannel.ValidColumn,
- channelmonitor.Table: channelmonitor.ValidColumn,
- channelmonitordailyrollup.Table: channelmonitordailyrollup.ValidColumn,
- channelmonitorhistory.Table: channelmonitorhistory.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
- paymentauditlog.Table: paymentauditlog.ValidColumn,
- paymentorder.Table: paymentorder.ValidColumn,
- paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
- pendingauthsession.Table: pendingauthsession.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- subscriptionplan.Table: subscriptionplan.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ channelmonitor.Table: channelmonitor.ValidColumn,
+ channelmonitordailyrollup.Table: channelmonitordailyrollup.ValidColumn,
+ channelmonitorhistory.Table: channelmonitorhistory.ValidColumn,
+ channelmonitorrequesttemplate.Table: channelmonitorrequesttemplate.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index ff86c90d..414eba24 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -129,6 +129,18 @@ func (f ChannelMonitorHistoryFunc) Mutate(ctx context.Context, m ent.Mutation) (
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorHistoryMutation", m)
}
+// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorRequestTemplate mutator.
+type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m) after asserting that m is a *ent.ChannelMonitorRequestTemplateMutation.
+func (f ChannelMonitorRequestTemplateFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+	if mv, ok := m.(*ent.ChannelMonitorRequestTemplateMutation); ok {
+		return f(ctx, mv)
+	}
+	return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorRequestTemplateMutation", m)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 0c83fc38..95b68e09 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
@@ -370,6 +371,33 @@ func (f TraverseChannelMonitorHistory) Traverse(ctx context.Context, q ent.Query
return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorHistoryQuery", q)
}
+// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) (ent.Value, error)
+
+// Query calls f(ctx, q) after asserting that q is a *ent.ChannelMonitorRequestTemplateQuery.
+func (f ChannelMonitorRequestTemplateFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+	if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok {
+		return f(ctx, q)
+	}
+	return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q)
+}
+
+// The TraverseChannelMonitorRequestTemplate type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorRequestTemplate func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) error
+
+// Intercept returns next unchanged; the query-type filtering for this adapter happens in Traverse.
+func (f TraverseChannelMonitorRequestTemplate) Intercept(next ent.Querier) ent.Querier {
+	return next
+}
+
+// Traverse calls f(ctx, q) after asserting that q is a *ent.ChannelMonitorRequestTemplateQuery.
+func (f TraverseChannelMonitorRequestTemplate) Traverse(ctx context.Context, q ent.Query) error {
+	if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok {
+		return f(ctx, q)
+	}
+	return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@@ -1014,6 +1042,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.ChannelMonitorDailyRollupQuery, predicate.ChannelMonitorDailyRollup, channelmonitordailyrollup.OrderOption]{typ: ent.TypeChannelMonitorDailyRollup, tq: q}, nil
case *ent.ChannelMonitorHistoryQuery:
return &query[*ent.ChannelMonitorHistoryQuery, predicate.ChannelMonitorHistory, channelmonitorhistory.OrderOption]{typ: ent.TypeChannelMonitorHistory, tq: q}, nil
+ case *ent.ChannelMonitorRequestTemplateQuery:
+ return &query[*ent.ChannelMonitorRequestTemplateQuery, predicate.ChannelMonitorRequestTemplate, channelmonitorrequesttemplate.OrderOption]{typ: ent.TypeChannelMonitorRequestTemplate, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index dba43ddf..38366e95 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -437,12 +437,24 @@ var (
{Name: "interval_seconds", Type: field.TypeInt},
{Name: "last_checked_at", Type: field.TypeTime, Nullable: true},
{Name: "created_by", Type: field.TypeInt64},
+ {Name: "extra_headers", Type: field.TypeJSON},
+ {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
+ {Name: "body_override", Type: field.TypeJSON, Nullable: true},
+ {Name: "template_id", Type: field.TypeInt64, Nullable: true},
}
// ChannelMonitorsTable holds the schema information for the "channel_monitors" table.
ChannelMonitorsTable = &schema.Table{
Name: "channel_monitors",
Columns: ChannelMonitorsColumns,
PrimaryKey: []*schema.Column{ChannelMonitorsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitors_channel_monitor_request_templates_request_template",
+ Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
Indexes: []*schema.Index{
{
Name: "channelmonitor_enabled_last_checked_at",
@@ -459,6 +471,11 @@ var (
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[9]},
},
+ {
+ Name: "channelmonitor_template_id",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ },
},
}
// ChannelMonitorDailyRollupsColumns holds the columns for the "channel_monitor_daily_rollups" table.
@@ -542,6 +559,31 @@ var (
},
},
}
+ // ChannelMonitorRequestTemplatesColumns holds the columns for the "channel_monitor_request_templates" table.
+ ChannelMonitorRequestTemplatesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "name", Type: field.TypeString, Size: 100},
+ {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
+ {Name: "extra_headers", Type: field.TypeJSON},
+ {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
+ {Name: "body_override", Type: field.TypeJSON, Nullable: true},
+ }
+ // ChannelMonitorRequestTemplatesTable holds the schema information for the "channel_monitor_request_templates" table.
+ ChannelMonitorRequestTemplatesTable = &schema.Table{
+ Name: "channel_monitor_request_templates",
+ Columns: ChannelMonitorRequestTemplatesColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitorrequesttemplate_provider_name",
+ Unique: true,
+ Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]},
+ },
+ },
+ }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -1644,6 +1686,7 @@ var (
ChannelMonitorsTable,
ChannelMonitorDailyRollupsTable,
ChannelMonitorHistoriesTable,
+ ChannelMonitorRequestTemplatesTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
@@ -1701,6 +1744,7 @@ func init() {
AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
Table: "auth_identity_channels",
}
+ ChannelMonitorsTable.ForeignKeys[0].RefTable = ChannelMonitorRequestTemplatesTable
ChannelMonitorsTable.Annotation = &entsql.Annotation{
Table: "channel_monitors",
}
@@ -1712,6 +1756,9 @@ func init() {
ChannelMonitorHistoriesTable.Annotation = &entsql.Annotation{
Table: "channel_monitor_histories",
}
+ ChannelMonitorRequestTemplatesTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_request_templates",
+ }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 43e52371..568b3eb5 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
@@ -58,39 +59,40 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeAuthIdentity = "AuthIdentity"
- TypeAuthIdentityChannel = "AuthIdentityChannel"
- TypeChannelMonitor = "ChannelMonitor"
- TypeChannelMonitorDailyRollup = "ChannelMonitorDailyRollup"
- TypeChannelMonitorHistory = "ChannelMonitorHistory"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
- TypePaymentAuditLog = "PaymentAuditLog"
- TypePaymentOrder = "PaymentOrder"
- TypePaymentProviderInstance = "PaymentProviderInstance"
- TypePendingAuthSession = "PendingAuthSession"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeSubscriptionPlan = "SubscriptionPlan"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeChannelMonitor = "ChannelMonitor"
+ TypeChannelMonitorDailyRollup = "ChannelMonitorDailyRollup"
+ TypeChannelMonitorHistory = "ChannelMonitorHistory"
+ TypeChannelMonitorRequestTemplate = "ChannelMonitorRequestTemplate"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -8743,35 +8745,40 @@ func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
// ChannelMonitorMutation represents an operation that mutates the ChannelMonitor nodes in the graph.
type ChannelMonitorMutation struct {
config
- op Op
- typ string
- id *int64
- created_at *time.Time
- updated_at *time.Time
- name *string
- provider *channelmonitor.Provider
- endpoint *string
- api_key_encrypted *string
- primary_model *string
- extra_models *[]string
- appendextra_models []string
- group_name *string
- enabled *bool
- interval_seconds *int
- addinterval_seconds *int
- last_checked_at *time.Time
- created_by *int64
- addcreated_by *int64
- clearedFields map[string]struct{}
- history map[int64]struct{}
- removedhistory map[int64]struct{}
- clearedhistory bool
- daily_rollups map[int64]struct{}
- removeddaily_rollups map[int64]struct{}
- cleareddaily_rollups bool
- done bool
- oldValue func(context.Context) (*ChannelMonitor, error)
- predicates []predicate.ChannelMonitor
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ provider *channelmonitor.Provider
+ endpoint *string
+ api_key_encrypted *string
+ primary_model *string
+ extra_models *[]string
+ appendextra_models []string
+ group_name *string
+ enabled *bool
+ interval_seconds *int
+ addinterval_seconds *int
+ last_checked_at *time.Time
+ created_by *int64
+ addcreated_by *int64
+ extra_headers *map[string]string
+ body_override_mode *string
+ body_override *map[string]interface{}
+ clearedFields map[string]struct{}
+ history map[int64]struct{}
+ removedhistory map[int64]struct{}
+ clearedhistory bool
+ daily_rollups map[int64]struct{}
+ removeddaily_rollups map[int64]struct{}
+ cleareddaily_rollups bool
+ request_template *int64
+ clearedrequest_template bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitor, error)
+ predicates []predicate.ChannelMonitor
}
var _ ent.Mutation = (*ChannelMonitorMutation)(nil)
@@ -9421,6 +9428,176 @@ func (m *ChannelMonitorMutation) ResetCreatedBy() {
m.addcreated_by = nil
}
+// SetTemplateID sets the "template_id" field.
+func (m *ChannelMonitorMutation) SetTemplateID(i int64) {
+	m.request_template = &i // the template_id field shares storage with the request_template edge pointer
+}
+
+// TemplateID returns the value of the "template_id" field in the mutation.
+func (m *ChannelMonitorMutation) TemplateID() (r int64, exists bool) {
+	v := m.request_template
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldTemplateID returns the old "template_id" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldTemplateID(ctx context.Context) (v *int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldTemplateID is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldTemplateID requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldTemplateID: %w", err)
+	}
+	return oldValue.TemplateID, nil
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (m *ChannelMonitorMutation) ClearTemplateID() {
+	m.request_template = nil // also drops any pending edge assignment, since field and edge share storage
+	m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{}
+}
+
+// TemplateIDCleared returns if the "template_id" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) TemplateIDCleared() bool {
+	_, ok := m.clearedFields[channelmonitor.FieldTemplateID]
+	return ok
+}
+
+// ResetTemplateID resets all changes to the "template_id" field.
+func (m *ChannelMonitorMutation) ResetTemplateID() {
+	m.request_template = nil
+	delete(m.clearedFields, channelmonitor.FieldTemplateID)
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (m *ChannelMonitorMutation) SetExtraHeaders(value map[string]string) {
+ m.extra_headers = &value
+}
+
+// ExtraHeaders returns the value of the "extra_headers" field in the mutation.
+func (m *ChannelMonitorMutation) ExtraHeaders() (r map[string]string, exists bool) {
+ v := m.extra_headers
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraHeaders requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err)
+ }
+ return oldValue.ExtraHeaders, nil
+}
+
+// ResetExtraHeaders resets all changes to the "extra_headers" field.
+func (m *ChannelMonitorMutation) ResetExtraHeaders() {
+ m.extra_headers = nil
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (m *ChannelMonitorMutation) SetBodyOverrideMode(s string) {
+ m.body_override_mode = &s
+}
+
+// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation.
+func (m *ChannelMonitorMutation) BodyOverrideMode() (r string, exists bool) {
+ v := m.body_override_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err)
+ }
+ return oldValue.BodyOverrideMode, nil
+}
+
+// ResetBodyOverrideMode resets all changes to the "body_override_mode" field.
+func (m *ChannelMonitorMutation) ResetBodyOverrideMode() {
+ m.body_override_mode = nil
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (m *ChannelMonitorMutation) SetBodyOverride(value map[string]interface{}) {
+ m.body_override = &value
+}
+
+// BodyOverride returns the value of the "body_override" field in the mutation.
+func (m *ChannelMonitorMutation) BodyOverride() (r map[string]interface{}, exists bool) {
+ v := m.body_override
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverride requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err)
+ }
+ return oldValue.BodyOverride, nil
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (m *ChannelMonitorMutation) ClearBodyOverride() {
+ m.body_override = nil
+ m.clearedFields[channelmonitor.FieldBodyOverride] = struct{}{}
+}
+
+// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) BodyOverrideCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldBodyOverride]
+ return ok
+}
+
+// ResetBodyOverride resets all changes to the "body_override" field.
+func (m *ChannelMonitorMutation) ResetBodyOverride() {
+ m.body_override = nil
+ delete(m.clearedFields, channelmonitor.FieldBodyOverride)
+}
+
// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by ids.
func (m *ChannelMonitorMutation) AddHistoryIDs(ids ...int64) {
if m.history == nil {
@@ -9529,6 +9706,46 @@ func (m *ChannelMonitorMutation) ResetDailyRollups() {
m.removeddaily_rollups = nil
}
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by id.
+func (m *ChannelMonitorMutation) SetRequestTemplateID(id int64) {
+	m.request_template = &id // same storage as the template_id field (edge-backed FK column)
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (m *ChannelMonitorMutation) ClearRequestTemplate() {
+	m.clearedrequest_template = true
+	m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{} // clearing the edge also clears the template_id column
+}
+
+// RequestTemplateCleared reports if the "request_template" edge to the ChannelMonitorRequestTemplate entity was cleared.
+func (m *ChannelMonitorMutation) RequestTemplateCleared() bool {
+	return m.TemplateIDCleared() || m.clearedrequest_template
+}
+
+// RequestTemplateID returns the "request_template" edge ID in the mutation.
+func (m *ChannelMonitorMutation) RequestTemplateID() (id int64, exists bool) {
+	if m.request_template != nil {
+		return *m.request_template, true
+	}
+	return
+}
+
+// RequestTemplateIDs returns the "request_template" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// RequestTemplateID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorMutation) RequestTemplateIDs() (ids []int64) {
+	if id := m.request_template; id != nil {
+		ids = append(ids, *id)
+	}
+	return
+}
+
+// ResetRequestTemplate resets all changes to the "request_template" edge.
+func (m *ChannelMonitorMutation) ResetRequestTemplate() {
+	m.request_template = nil
+	m.clearedrequest_template = false
+}
+
// Where appends a list predicates to the ChannelMonitorMutation builder.
func (m *ChannelMonitorMutation) Where(ps ...predicate.ChannelMonitor) {
m.predicates = append(m.predicates, ps...)
@@ -9563,7 +9780,7 @@ func (m *ChannelMonitorMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorMutation) Fields() []string {
- fields := make([]string, 0, 13)
+ fields := make([]string, 0, 17)
if m.created_at != nil {
fields = append(fields, channelmonitor.FieldCreatedAt)
}
@@ -9603,6 +9820,18 @@ func (m *ChannelMonitorMutation) Fields() []string {
if m.created_by != nil {
fields = append(fields, channelmonitor.FieldCreatedBy)
}
+ if m.request_template != nil {
+ fields = append(fields, channelmonitor.FieldTemplateID)
+ }
+ if m.extra_headers != nil {
+ fields = append(fields, channelmonitor.FieldExtraHeaders)
+ }
+ if m.body_override_mode != nil {
+ fields = append(fields, channelmonitor.FieldBodyOverrideMode)
+ }
+ if m.body_override != nil {
+ fields = append(fields, channelmonitor.FieldBodyOverride)
+ }
return fields
}
@@ -9637,6 +9866,14 @@ func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) {
return m.LastCheckedAt()
case channelmonitor.FieldCreatedBy:
return m.CreatedBy()
+ case channelmonitor.FieldTemplateID:
+ return m.TemplateID()
+ case channelmonitor.FieldExtraHeaders:
+ return m.ExtraHeaders()
+ case channelmonitor.FieldBodyOverrideMode:
+ return m.BodyOverrideMode()
+ case channelmonitor.FieldBodyOverride:
+ return m.BodyOverride()
}
return nil, false
}
@@ -9672,6 +9909,14 @@ func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent
return m.OldLastCheckedAt(ctx)
case channelmonitor.FieldCreatedBy:
return m.OldCreatedBy(ctx)
+ case channelmonitor.FieldTemplateID:
+ return m.OldTemplateID(ctx)
+ case channelmonitor.FieldExtraHeaders:
+ return m.OldExtraHeaders(ctx)
+ case channelmonitor.FieldBodyOverrideMode:
+ return m.OldBodyOverrideMode(ctx)
+ case channelmonitor.FieldBodyOverride:
+ return m.OldBodyOverride(ctx)
}
return nil, fmt.Errorf("unknown ChannelMonitor field %s", name)
}
@@ -9772,6 +10017,34 @@ func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error {
}
m.SetCreatedBy(v)
return nil
+ case channelmonitor.FieldTemplateID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTemplateID(v)
+ return nil
+ case channelmonitor.FieldExtraHeaders:
+ v, ok := value.(map[string]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraHeaders(v)
+ return nil
+ case channelmonitor.FieldBodyOverrideMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverrideMode(v)
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverride(v)
+ return nil
}
return fmt.Errorf("unknown ChannelMonitor field %s", name)
}
@@ -9835,6 +10108,12 @@ func (m *ChannelMonitorMutation) ClearedFields() []string {
if m.FieldCleared(channelmonitor.FieldLastCheckedAt) {
fields = append(fields, channelmonitor.FieldLastCheckedAt)
}
+ if m.FieldCleared(channelmonitor.FieldTemplateID) {
+ fields = append(fields, channelmonitor.FieldTemplateID)
+ }
+ if m.FieldCleared(channelmonitor.FieldBodyOverride) {
+ fields = append(fields, channelmonitor.FieldBodyOverride)
+ }
return fields
}
@@ -9855,6 +10134,12 @@ func (m *ChannelMonitorMutation) ClearField(name string) error {
case channelmonitor.FieldLastCheckedAt:
m.ClearLastCheckedAt()
return nil
+ case channelmonitor.FieldTemplateID:
+ m.ClearTemplateID()
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ m.ClearBodyOverride()
+ return nil
}
return fmt.Errorf("unknown ChannelMonitor nullable field %s", name)
}
@@ -9902,19 +10187,34 @@ func (m *ChannelMonitorMutation) ResetField(name string) error {
case channelmonitor.FieldCreatedBy:
m.ResetCreatedBy()
return nil
+ case channelmonitor.FieldTemplateID:
+ m.ResetTemplateID()
+ return nil
+ case channelmonitor.FieldExtraHeaders:
+ m.ResetExtraHeaders()
+ return nil
+ case channelmonitor.FieldBodyOverrideMode:
+ m.ResetBodyOverrideMode()
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ m.ResetBodyOverride()
+ return nil
}
return fmt.Errorf("unknown ChannelMonitor field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *ChannelMonitorMutation) AddedEdges() []string {
- edges := make([]string, 0, 2)
+ edges := make([]string, 0, 3)
if m.history != nil {
edges = append(edges, channelmonitor.EdgeHistory)
}
if m.daily_rollups != nil {
edges = append(edges, channelmonitor.EdgeDailyRollups)
}
+ if m.request_template != nil {
+ edges = append(edges, channelmonitor.EdgeRequestTemplate)
+ }
return edges
}
@@ -9934,13 +10234,17 @@ func (m *ChannelMonitorMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case channelmonitor.EdgeRequestTemplate:
+ if id := m.request_template; id != nil {
+ return []ent.Value{*id}
+ }
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *ChannelMonitorMutation) RemovedEdges() []string {
- edges := make([]string, 0, 2)
+ edges := make([]string, 0, 3)
if m.removedhistory != nil {
edges = append(edges, channelmonitor.EdgeHistory)
}
@@ -9972,13 +10276,16 @@ func (m *ChannelMonitorMutation) RemovedIDs(name string) []ent.Value {
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *ChannelMonitorMutation) ClearedEdges() []string {
- edges := make([]string, 0, 2)
+ edges := make([]string, 0, 3)
if m.clearedhistory {
edges = append(edges, channelmonitor.EdgeHistory)
}
if m.cleareddaily_rollups {
edges = append(edges, channelmonitor.EdgeDailyRollups)
}
+ if m.clearedrequest_template {
+ edges = append(edges, channelmonitor.EdgeRequestTemplate)
+ }
return edges
}
@@ -9990,6 +10297,8 @@ func (m *ChannelMonitorMutation) EdgeCleared(name string) bool {
return m.clearedhistory
case channelmonitor.EdgeDailyRollups:
return m.cleareddaily_rollups
+ case channelmonitor.EdgeRequestTemplate:
+ return m.clearedrequest_template
}
return false
}
@@ -9998,6 +10307,9 @@ func (m *ChannelMonitorMutation) EdgeCleared(name string) bool {
// if that edge is not defined in the schema.
func (m *ChannelMonitorMutation) ClearEdge(name string) error {
switch name {
+ case channelmonitor.EdgeRequestTemplate:
+ m.ClearRequestTemplate()
+ return nil
}
return fmt.Errorf("unknown ChannelMonitor unique edge %s", name)
}
@@ -10012,6 +10324,9 @@ func (m *ChannelMonitorMutation) ResetEdge(name string) error {
case channelmonitor.EdgeDailyRollups:
m.ResetDailyRollups()
return nil
+ case channelmonitor.EdgeRequestTemplate:
+ m.ResetRequestTemplate()
+ return nil
}
return fmt.Errorf("unknown ChannelMonitor edge %s", name)
}
@@ -12266,6 +12581,844 @@ func (m *ChannelMonitorHistoryMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown ChannelMonitorHistory edge %s", name)
}
+// ChannelMonitorRequestTemplateMutation represents an operation that mutates the ChannelMonitorRequestTemplate nodes in the graph.
+type ChannelMonitorRequestTemplateMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ provider *channelmonitorrequesttemplate.Provider
+ description *string
+ extra_headers *map[string]string
+ body_override_mode *string
+ body_override *map[string]interface{}
+ clearedFields map[string]struct{}
+ monitors map[int64]struct{}
+ removedmonitors map[int64]struct{}
+ clearedmonitors bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorRequestTemplate, error)
+ predicates []predicate.ChannelMonitorRequestTemplate
+}
+
+var _ ent.Mutation = (*ChannelMonitorRequestTemplateMutation)(nil)
+
+// channelmonitorrequesttemplateOption allows management of the mutation configuration using functional options.
+type channelmonitorrequesttemplateOption func(*ChannelMonitorRequestTemplateMutation)
+
+// newChannelMonitorRequestTemplateMutation creates new mutation for the ChannelMonitorRequestTemplate entity.
+func newChannelMonitorRequestTemplateMutation(c config, op Op, opts ...channelmonitorrequesttemplateOption) *ChannelMonitorRequestTemplateMutation {
+ m := &ChannelMonitorRequestTemplateMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitorRequestTemplate,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorRequestTemplateID sets the ID field of the mutation.
+func withChannelMonitorRequestTemplateID(id int64) channelmonitorrequesttemplateOption {
+ return func(m *ChannelMonitorRequestTemplateMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitorRequestTemplate
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitorRequestTemplate.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitorRequestTemplate sets the old ChannelMonitorRequestTemplate of the mutation.
+func withChannelMonitorRequestTemplate(node *ChannelMonitorRequestTemplate) channelmonitorrequesttemplateOption {
+ return func(m *ChannelMonitorRequestTemplateMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorRequestTemplate, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorRequestTemplateMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorRequestTemplateMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorRequestTemplateMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorRequestTemplate.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetName sets the "name" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetName() {
+ m.name = nil
+}
+
+// SetProvider sets the "provider" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetProvider(c channelmonitorrequesttemplate.Provider) {
+ m.provider = &c
+}
+
+// Provider returns the value of the "provider" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Provider() (r channelmonitorrequesttemplate.Provider, exists bool) {
+ v := m.provider
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProvider returns the old "provider" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldProvider(ctx context.Context) (v channelmonitorrequesttemplate.Provider, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProvider is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProvider requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProvider: %w", err)
+ }
+ return oldValue.Provider, nil
+}
+
+// ResetProvider resets all changes to the "provider" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() {
+ m.provider = nil
+}
+
+// SetDescription sets the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) {
+ m.description = &s
+}
+
+// Description returns the value of the "description" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Description() (r string, exists bool) {
+ v := m.description
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDescription returns the old "description" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldDescription(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDescription requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ }
+ return oldValue.Description, nil
+}
+
+// ClearDescription clears the value of the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) ClearDescription() {
+ m.description = nil
+ m.clearedFields[channelmonitorrequesttemplate.FieldDescription] = struct{}{}
+}
+
+// DescriptionCleared returns if the "description" field was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) DescriptionCleared() bool {
+ _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldDescription]
+ return ok
+}
+
+// ResetDescription resets all changes to the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetDescription() {
+ m.description = nil
+ delete(m.clearedFields, channelmonitorrequesttemplate.FieldDescription)
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetExtraHeaders(value map[string]string) {
+ m.extra_headers = &value
+}
+
+// ExtraHeaders returns the value of the "extra_headers" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ExtraHeaders() (r map[string]string, exists bool) {
+ v := m.extra_headers
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraHeaders requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err)
+ }
+ return oldValue.ExtraHeaders, nil
+}
+
+// ResetExtraHeaders resets all changes to the "extra_headers" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetExtraHeaders() {
+ m.extra_headers = nil
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverrideMode(s string) {
+ m.body_override_mode = &s
+}
+
+// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideMode() (r string, exists bool) {
+ v := m.body_override_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err)
+ }
+ return oldValue.BodyOverrideMode, nil
+}
+
+// ResetBodyOverrideMode resets all changes to the "body_override_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverrideMode() {
+ m.body_override_mode = nil
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverride(value map[string]interface{}) {
+ m.body_override = &value
+}
+
+// BodyOverride returns the value of the "body_override" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverride() (r map[string]interface{}, exists bool) {
+ v := m.body_override
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverride requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err)
+ }
+ return oldValue.BodyOverride, nil
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) ClearBodyOverride() {
+ m.body_override = nil
+ m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride] = struct{}{}
+}
+
+// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideCleared() bool {
+ _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride]
+ return ok
+}
+
+// ResetBodyOverride resets all changes to the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverride() {
+ m.body_override = nil
+ delete(m.clearedFields, channelmonitorrequesttemplate.FieldBodyOverride)
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by ids.
+func (m *ChannelMonitorRequestTemplateMutation) AddMonitorIDs(ids ...int64) {
+ if m.monitors == nil {
+ m.monitors = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.monitors[ids[i]] = struct{}{}
+ }
+}
+
+// ClearMonitors clears the "monitors" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorRequestTemplateMutation) ClearMonitors() {
+ m.clearedmonitors = true
+}
+
+// MonitorsCleared reports if the "monitors" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsCleared() bool {
+ return m.clearedmonitors
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to the ChannelMonitor entity by IDs.
+func (m *ChannelMonitorRequestTemplateMutation) RemoveMonitorIDs(ids ...int64) {
+ if m.removedmonitors == nil {
+ m.removedmonitors = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.monitors, ids[i])
+ m.removedmonitors[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedMonitors returns the removed IDs of the "monitors" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedMonitorsIDs() (ids []int64) {
+ for id := range m.removedmonitors {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// MonitorsIDs returns the "monitors" edge IDs in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsIDs() (ids []int64) {
+ for id := range m.monitors {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetMonitors resets all changes to the "monitors" edge.
+func (m *ChannelMonitorRequestTemplateMutation) ResetMonitors() {
+ m.monitors = nil
+ m.clearedmonitors = false
+ m.removedmonitors = nil
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateMutation builder.
+func (m *ChannelMonitorRequestTemplateMutation) Where(ps ...predicate.ChannelMonitorRequestTemplate) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorRequestTemplateMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorRequestTemplateMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitorRequestTemplate, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorRequestTemplateMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorRequestTemplateMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorRequestTemplate).
+func (m *ChannelMonitorRequestTemplateMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
+ fields := make([]string, 0, 8)
+ if m.created_at != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldUpdatedAt)
+ }
+ if m.name != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldName)
+ }
+ if m.provider != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldProvider)
+ }
+ if m.description != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
+ }
+ if m.extra_headers != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldExtraHeaders)
+ }
+ if m.body_override_mode != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverrideMode)
+ }
+ if m.body_override != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ return m.CreatedAt()
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case channelmonitorrequesttemplate.FieldName:
+ return m.Name()
+ case channelmonitorrequesttemplate.FieldProvider:
+ return m.Provider()
+ case channelmonitorrequesttemplate.FieldDescription:
+ return m.Description()
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ return m.ExtraHeaders()
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ return m.BodyOverrideMode()
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ return m.BodyOverride()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case channelmonitorrequesttemplate.FieldName:
+ return m.OldName(ctx)
+ case channelmonitorrequesttemplate.FieldProvider:
+ return m.OldProvider(ctx)
+ case channelmonitorrequesttemplate.FieldDescription:
+ return m.OldDescription(ctx)
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ return m.OldExtraHeaders(ctx)
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ return m.OldBodyOverrideMode(ctx)
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ return m.OldBodyOverride(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldProvider:
+ v, ok := value.(channelmonitorrequesttemplate.Provider)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProvider(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldDescription:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDescription(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ v, ok := value.(map[string]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraHeaders(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverrideMode(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverride(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedFields() []string {
+ return nil
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) AddedField(name string) (ent.Value, bool) {
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorRequestTemplateMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitorrequesttemplate.FieldDescription) {
+ fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
+ }
+ if m.FieldCleared(channelmonitorrequesttemplate.FieldBodyOverride) {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldDescription:
+ m.ClearDescription()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ m.ClearBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case channelmonitorrequesttemplate.FieldName:
+ m.ResetName()
+ return nil
+ case channelmonitorrequesttemplate.FieldProvider:
+ m.ResetProvider()
+ return nil
+ case channelmonitorrequesttemplate.FieldDescription:
+ m.ResetDescription()
+ return nil
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ m.ResetExtraHeaders()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ m.ResetBodyOverrideMode()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ m.ResetBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitors != nil {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ ids := make([]ent.Value, 0, len(m.monitors))
+ for id := range m.monitors {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.removedmonitors != nil {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ ids := make([]ent.Value, 0, len(m.removedmonitors))
+ for id := range m.removedmonitors {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedmonitors {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ return m.clearedmonitors
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ m.ResetMonitors()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate edge %s", name)
+}
+
// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
type ErrorPassthroughRuleMutation struct {
config
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index adb9a085..dc86471e 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -36,6 +36,9 @@ type ChannelMonitorDailyRollup func(*sql.Selector)
// ChannelMonitorHistory is the predicate function for channelmonitorhistory builders.
type ChannelMonitorHistory func(*sql.Selector)
+// ChannelMonitorRequestTemplate is the predicate function for channelmonitorrequesttemplate builders.
+type ChannelMonitorRequestTemplate func(*sql.Selector)
+
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 63552bb5..aaa939c5 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
@@ -521,6 +522,16 @@ func init() {
channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
// channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
+ // channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field.
+ channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor()
+ // channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
+ channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string)
+ // channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field.
+ channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor()
+ // channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
+ channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string)
+ // channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ channelmonitor.BodyOverrideModeValidator = channelmonitorDescBodyOverrideMode.Validators[0].(func(string) error)
channelmonitordailyrollupFields := schema.ChannelMonitorDailyRollup{}.Fields()
_ = channelmonitordailyrollupFields
// channelmonitordailyrollupDescModel is the schema descriptor for model field.
@@ -617,6 +628,55 @@ func init() {
channelmonitorhistoryDescCheckedAt := channelmonitorhistoryFields[6].Descriptor()
// channelmonitorhistory.DefaultCheckedAt holds the default value on creation for the checked_at field.
channelmonitorhistory.DefaultCheckedAt = channelmonitorhistoryDescCheckedAt.Default.(func() time.Time)
+ channelmonitorrequesttemplateMixin := schema.ChannelMonitorRequestTemplate{}.Mixin()
+ channelmonitorrequesttemplateMixinFields0 := channelmonitorrequesttemplateMixin[0].Fields()
+ _ = channelmonitorrequesttemplateMixinFields0
+ channelmonitorrequesttemplateFields := schema.ChannelMonitorRequestTemplate{}.Fields()
+ _ = channelmonitorrequesttemplateFields
+ // channelmonitorrequesttemplateDescCreatedAt is the schema descriptor for created_at field.
+ channelmonitorrequesttemplateDescCreatedAt := channelmonitorrequesttemplateMixinFields0[0].Descriptor()
+ // channelmonitorrequesttemplate.DefaultCreatedAt holds the default value on creation for the created_at field.
+ channelmonitorrequesttemplate.DefaultCreatedAt = channelmonitorrequesttemplateDescCreatedAt.Default.(func() time.Time)
+ // channelmonitorrequesttemplateDescUpdatedAt is the schema descriptor for updated_at field.
+ channelmonitorrequesttemplateDescUpdatedAt := channelmonitorrequesttemplateMixinFields0[1].Descriptor()
+ // channelmonitorrequesttemplate.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ channelmonitorrequesttemplate.DefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.Default.(func() time.Time)
+ // channelmonitorrequesttemplate.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ channelmonitorrequesttemplate.UpdateDefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // channelmonitorrequesttemplateDescName is the schema descriptor for name field.
+ channelmonitorrequesttemplateDescName := channelmonitorrequesttemplateFields[0].Descriptor()
+ // channelmonitorrequesttemplate.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.NameValidator = func() func(string) error {
+ validators := channelmonitorrequesttemplateDescName.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(name string) error {
+ for _, fn := range fns {
+ if err := fn(name); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorrequesttemplateDescDescription is the schema descriptor for description field.
+ channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor()
+ // channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field.
+ channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string)
+ // channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error)
+ // channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field.
+ channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor()
+ // channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
+ channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string)
+ // channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field.
+ channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor()
+ // channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
+ channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string)
+ // channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.BodyOverrideModeValidator = channelmonitorrequesttemplateDescBodyOverrideMode.Validators[0].(func(string) error)
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
diff --git a/backend/ent/schema/channel_monitor.go b/backend/ent/schema/channel_monitor.go
index f6a6578d..355ade4b 100644
--- a/backend/ent/schema/channel_monitor.go
+++ b/backend/ent/schema/channel_monitor.go
@@ -62,6 +62,26 @@ func (ChannelMonitor) Fields() []ent.Field {
Optional().
Nillable(),
field.Int64("created_by"),
+
+ // ---- 自定义请求快照字段(来自模板 / 手动编辑) ----
+
+ // template_id: 关联的请求模板 ID(仅用于 UI 分组 + 一键应用)。
+ // 实际运行时 checker 只读下面 3 个快照字段,**不再回查模板表**。
+ // 模板被删除时此字段会被 SET NULL(见 Edges 的 OnDelete 注解)。
+ field.Int64("template_id").
+ Optional().
+ Nillable(),
+ // extra_headers: 自定义 HTTP 头快照(来自模板 or 用户手填)。
+ // 运行时 merge 进 adapter 默认 headers。
+ field.JSON("extra_headers", map[string]string{}).
+ Default(map[string]string{}),
+ // body_override_mode: 同 ChannelMonitorRequestTemplate.body_override_mode
+ field.String("body_override_mode").
+ Default("off").
+ MaxLen(10),
+ // body_override: 同 ChannelMonitorRequestTemplate.body_override
+ field.JSON("body_override", map[string]any{}).
+ Optional(),
}
}
@@ -71,6 +91,12 @@ func (ChannelMonitor) Edges() []ent.Edge {
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("daily_rollups", ChannelMonitorDailyRollup.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
+ // 关联请求模板:模板被删除时 template_id 自动置空,
+ // 监控本身保留(继续用快照字段跑)。
+ edge.To("request_template", ChannelMonitorRequestTemplate.Type).
+ Field("template_id").
+ Unique().
+ Annotations(entsql.OnDelete(entsql.SetNull)),
}
}
@@ -79,5 +105,6 @@ func (ChannelMonitor) Indexes() []ent.Index {
index.Fields("enabled", "last_checked_at"),
index.Fields("provider"),
index.Fields("group_name"),
+ index.Fields("template_id"),
}
}
diff --git a/backend/ent/schema/channel_monitor_request_template.go b/backend/ent/schema/channel_monitor_request_template.go
new file mode 100644
index 00000000..59df2f29
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_request_template.go
@@ -0,0 +1,80 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorRequestTemplate 请求模板:一组可复用的 headers + 可选 body 覆盖配置。
+//
+// 语义为快照:模板被"应用"到监控时,extra_headers / body_override_mode / body_override
+// 会被**拷贝**到 channel_monitors 同名字段;后续模板变动不会自动影响已应用的监控——
+// 必须用户主动在模板编辑 Dialog 里点「应用到关联监控」才会覆盖快照。
+// 这样模板改错不会瞬间打挂所有已经跑起来的监控。
+type ChannelMonitorRequestTemplate struct {
+ ent.Schema
+}
+
+func (ChannelMonitorRequestTemplate) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_request_templates"},
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("name").
+ NotEmpty().
+ MaxLen(100),
+ field.Enum("provider").
+ Values("openai", "anthropic", "gemini"),
+ field.String("description").
+ Optional().
+ Default("").
+ MaxLen(500),
+ // extra_headers: 用户自定义 HTTP 头(如 User-Agent 伪装)。
+ // 运行时 merge 进 adapter 默认 headers,用户值优先;
+ // hop-by-hop 黑名单(Host/Content-Length/...)由 checker 过滤。
+ field.JSON("extra_headers", map[string]string{}).
+ Default(map[string]string{}),
+ // body_override_mode: 'off' | 'merge' | 'replace'
+ // off - 用 adapter 默认 body(忽略 body_override)
+ // merge - adapter 默认 body 与 body_override 浅合并(body_override 优先,
+ // model/messages/contents 等关键字段在 checker 里走黑名单跳过)
+ // replace - 直接用 body_override 作为完整 body;此时跳过 challenge 校验,
+ // 改为 HTTP 2xx + 响应文本非空即视为可用
+ field.String("body_override_mode").
+ Default("off").
+ MaxLen(10),
+ // body_override: JSON 对象,根据 body_override_mode 使用。
+ // 用 map[string]any 以便前端传任意结构(含嵌套)。
+ field.JSON("body_override", map[string]any{}).
+ Optional(),
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitors", ChannelMonitor.Type).
+ Ref("request_template"),
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
+ return []ent.Index{
+ // 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
+ index.Fields("provider", "name").Unique(),
+ }
+}
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index 0e65a940..611028e9 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -34,6 +34,8 @@ type Tx struct {
ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
// ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
ChannelMonitorHistory *ChannelMonitorHistoryClient
+ // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
+ ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
@@ -221,6 +223,7 @@ func (tx *Tx) init() {
tx.ChannelMonitor = NewChannelMonitorClient(tx.config)
tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config)
tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config)
+ tx.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
diff --git a/backend/internal/handler/admin/channel_monitor_handler.go b/backend/internal/handler/admin/channel_monitor_handler.go
index ce86c3dc..e92c81fe 100644
--- a/backend/internal/handler/admin/channel_monitor_handler.go
+++ b/backend/internal/handler/admin/channel_monitor_handler.go
@@ -36,27 +36,36 @@ func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *Ch
// --- Request / Response ---
type channelMonitorCreateRequest struct {
- Name string `json:"name" binding:"required,max=100"`
- Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
- Endpoint string `json:"endpoint" binding:"required,max=500"`
- APIKey string `json:"api_key" binding:"required,max=2000"`
- PrimaryModel string `json:"primary_model" binding:"required,max=200"`
- ExtraModels []string `json:"extra_models"`
- GroupName string `json:"group_name" binding:"max=100"`
- Enabled *bool `json:"enabled"`
- IntervalSeconds int `json:"interval_seconds" binding:"required,min=15,max=3600"`
+ Name string `json:"name" binding:"required,max=100"`
+ Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ Endpoint string `json:"endpoint" binding:"required,max=500"`
+ APIKey string `json:"api_key" binding:"required,max=2000"`
+ PrimaryModel string `json:"primary_model" binding:"required,max=200"`
+ ExtraModels []string `json:"extra_models"`
+ GroupName string `json:"group_name" binding:"max=100"`
+ Enabled *bool `json:"enabled"`
+ IntervalSeconds int `json:"interval_seconds" binding:"required,min=15,max=3600"`
+ TemplateID *int64 `json:"template_id"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorUpdateRequest struct {
- Name *string `json:"name" binding:"omitempty,max=100"`
- Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
- Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
- APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
- PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
- ExtraModels *[]string `json:"extra_models"`
- GroupName *string `json:"group_name" binding:"omitempty,max=100"`
- Enabled *bool `json:"enabled"`
- IntervalSeconds *int `json:"interval_seconds" binding:"omitempty,min=15,max=3600"`
+ Name *string `json:"name" binding:"omitempty,max=100"`
+ Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
+ Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
+ APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
+ PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
+ ExtraModels *[]string `json:"extra_models"`
+ GroupName *string `json:"group_name" binding:"omitempty,max=100"`
+ Enabled *bool `json:"enabled"`
+ IntervalSeconds *int `json:"interval_seconds" binding:"omitempty,min=15,max=3600"`
+ TemplateID *int64 `json:"template_id"`
+ ClearTemplate bool `json:"clear_template"` // true 时把 template_id 置空,忽略 TemplateID
+ ExtraHeaders *map[string]string `json:"extra_headers"`
+ BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride *map[string]any `json:"body_override"`
}
type channelMonitorResponse struct {
@@ -79,6 +88,11 @@ type channelMonitorResponse struct {
PrimaryLatencyMs *int `json:"primary_latency_ms"`
Availability7d float64 `json:"availability_7d"`
ExtraModelsStatus []dto.ChannelMonitorExtraModelStatus `json:"extra_models_status"`
+ // 请求自定义快照:前端编辑 / 展示「高级设置」用
+ TemplateID *int64 `json:"template_id"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode"`
+ BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorCheckResultResponse struct {
@@ -116,6 +130,10 @@ func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse
if extras == nil {
extras = []string{}
}
+ headers := m.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
resp := &channelMonitorResponse{
ID: m.ID,
Name: m.Name,
@@ -131,6 +149,10 @@ func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse
CreatedBy: m.CreatedBy,
CreatedAt: m.CreatedAt.UTC().Format(time.RFC3339),
UpdatedAt: m.UpdatedAt.UTC().Format(time.RFC3339),
+ TemplateID: m.TemplateID,
+ ExtraHeaders: headers,
+ BodyOverrideMode: m.BodyOverrideMode,
+ BodyOverride: m.BodyOverride,
// PrimaryStatus / PrimaryLatencyMs / Availability7d 由 List handler 在批量聚合后填充。
}
if m.LastCheckedAt != nil {
@@ -279,16 +301,20 @@ func (h *ChannelMonitorHandler) Create(c *gin.Context) {
}
m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
- Name: req.Name,
- Provider: req.Provider,
- Endpoint: req.Endpoint,
- APIKey: req.APIKey,
- PrimaryModel: req.PrimaryModel,
- ExtraModels: req.ExtraModels,
- GroupName: req.GroupName,
- Enabled: enabled,
- IntervalSeconds: req.IntervalSeconds,
- CreatedBy: subject.UserID,
+ Name: req.Name,
+ Provider: req.Provider,
+ Endpoint: req.Endpoint,
+ APIKey: req.APIKey,
+ PrimaryModel: req.PrimaryModel,
+ ExtraModels: req.ExtraModels,
+ GroupName: req.GroupName,
+ Enabled: enabled,
+ IntervalSeconds: req.IntervalSeconds,
+ CreatedBy: subject.UserID,
+ TemplateID: req.TemplateID,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -310,15 +336,20 @@ func (h *ChannelMonitorHandler) Update(c *gin.Context) {
}
m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
- Name: req.Name,
- Provider: req.Provider,
- Endpoint: req.Endpoint,
- APIKey: req.APIKey,
- PrimaryModel: req.PrimaryModel,
- ExtraModels: req.ExtraModels,
- GroupName: req.GroupName,
- Enabled: req.Enabled,
- IntervalSeconds: req.IntervalSeconds,
+ Name: req.Name,
+ Provider: req.Provider,
+ Endpoint: req.Endpoint,
+ APIKey: req.APIKey,
+ PrimaryModel: req.PrimaryModel,
+ ExtraModels: req.ExtraModels,
+ GroupName: req.GroupName,
+ Enabled: req.Enabled,
+ IntervalSeconds: req.IntervalSeconds,
+ TemplateID: req.TemplateID,
+ ClearTemplate: req.ClearTemplate,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
diff --git a/backend/internal/handler/admin/channel_monitor_template_handler.go b/backend/internal/handler/admin/channel_monitor_template_handler.go
new file mode 100644
index 00000000..8c1191ea
--- /dev/null
+++ b/backend/internal/handler/admin/channel_monitor_template_handler.go
@@ -0,0 +1,195 @@
+package admin
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ChannelMonitorRequestTemplateHandler 请求模板管理后台 handler。
+type ChannelMonitorRequestTemplateHandler struct {
+ templateService *service.ChannelMonitorRequestTemplateService
+}
+
+// NewChannelMonitorRequestTemplateHandler 创建 handler。
+func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMonitorRequestTemplateService) *ChannelMonitorRequestTemplateHandler {
+ return &ChannelMonitorRequestTemplateHandler{templateService: templateService}
+}
+
+// --- DTO ---
+
+type channelMonitorTemplateCreateRequest struct {
+ Name string `json:"name" binding:"required,max=100"`
+ Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ Description string `json:"description" binding:"max=500"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride map[string]any `json:"body_override"`
+}
+
+type channelMonitorTemplateUpdateRequest struct {
+ Name *string `json:"name" binding:"omitempty,max=100"`
+ Description *string `json:"description" binding:"omitempty,max=500"`
+ ExtraHeaders *map[string]string `json:"extra_headers"`
+ BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride *map[string]any `json:"body_override"`
+}
+
+type channelMonitorTemplateResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Description string `json:"description"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode"`
+ BodyOverride map[string]any `json:"body_override"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
+ AssociatedMonitors int64 `json:"associated_monitors"`
+}
+
+func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *service.ChannelMonitorRequestTemplate) *channelMonitorTemplateResponse {
+ if t == nil {
+ return nil
+ }
+ headers := t.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ count, _ := h.templateService.CountAssociatedMonitors(c.Request.Context(), t.ID)
+ return &channelMonitorTemplateResponse{
+ ID: t.ID,
+ Name: t.Name,
+ Provider: t.Provider,
+ Description: t.Description,
+ ExtraHeaders: headers,
+ BodyOverrideMode: t.BodyOverrideMode,
+ BodyOverride: t.BodyOverride,
+ CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
+ UpdatedAt: t.UpdatedAt.UTC().Format(time.RFC3339),
+ AssociatedMonitors: count,
+ }
+}
+
+// parseTemplateID 提取并校验 :id。
+func parseTemplateID(c *gin.Context) (int64, bool) {
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || id <= 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("INVALID_TEMPLATE_ID", "invalid template id"))
+ return 0, false
+ }
+ return id, true
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/admin/channel-monitor-templates?provider=anthropic
+func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
+ items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
+ Provider: strings.TrimSpace(c.Query("provider")),
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]*channelMonitorTemplateResponse, 0, len(items))
+ for _, t := range items {
+ out = append(out, h.toResponse(c, t))
+ }
+ response.Success(c, gin.H{"items": out})
+}
+
+// Get GET /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Get(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ t, err := h.templateService.Get(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, h.toResponse(c, t))
+}
+
+// Create POST /api/v1/admin/channel-monitor-templates
+func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
+ var req channelMonitorTemplateCreateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
+ Name: req.Name,
+ Provider: req.Provider,
+ Description: req.Description,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Created(c, h.toResponse(c, t))
+}
+
+// Update PUT /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ var req channelMonitorTemplateUpdateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
+ Name: req.Name,
+ Description: req.Description,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, h.toResponse(c, t))
+}
+
+// Delete DELETE /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ if err := h.templateService.Delete(c.Request.Context(), id); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, nil)
+}
+
+// Apply POST /api/v1/admin/channel-monitor-templates/:id/apply
+// 一键把模板当前配置覆盖到所有关联监控上。
+func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"affected": affected})
+}
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index 58480c93..bedb81ae 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -6,33 +6,34 @@ import (
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
- Dashboard *admin.DashboardHandler
- User *admin.UserHandler
- Group *admin.GroupHandler
- Account *admin.AccountHandler
- Announcement *admin.AnnouncementHandler
- DataManagement *admin.DataManagementHandler
- Backup *admin.BackupHandler
- OAuth *admin.OAuthHandler
- OpenAIOAuth *admin.OpenAIOAuthHandler
- GeminiOAuth *admin.GeminiOAuthHandler
- AntigravityOAuth *admin.AntigravityOAuthHandler
- Proxy *admin.ProxyHandler
- Redeem *admin.RedeemHandler
- Promo *admin.PromoHandler
- Setting *admin.SettingHandler
- Ops *admin.OpsHandler
- System *admin.SystemHandler
- Subscription *admin.SubscriptionHandler
- Usage *admin.UsageHandler
- UserAttribute *admin.UserAttributeHandler
- ErrorPassthrough *admin.ErrorPassthroughHandler
- TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
- APIKey *admin.AdminAPIKeyHandler
- ScheduledTest *admin.ScheduledTestHandler
- Channel *admin.ChannelHandler
- ChannelMonitor *admin.ChannelMonitorHandler
- Payment *admin.PaymentHandler
+ Dashboard *admin.DashboardHandler
+ User *admin.UserHandler
+ Group *admin.GroupHandler
+ Account *admin.AccountHandler
+ Announcement *admin.AnnouncementHandler
+ DataManagement *admin.DataManagementHandler
+ Backup *admin.BackupHandler
+ OAuth *admin.OAuthHandler
+ OpenAIOAuth *admin.OpenAIOAuthHandler
+ GeminiOAuth *admin.GeminiOAuthHandler
+ AntigravityOAuth *admin.AntigravityOAuthHandler
+ Proxy *admin.ProxyHandler
+ Redeem *admin.RedeemHandler
+ Promo *admin.PromoHandler
+ Setting *admin.SettingHandler
+ Ops *admin.OpsHandler
+ System *admin.SystemHandler
+ Subscription *admin.SubscriptionHandler
+ Usage *admin.UsageHandler
+ UserAttribute *admin.UserAttributeHandler
+ ErrorPassthrough *admin.ErrorPassthroughHandler
+ TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
+ APIKey *admin.AdminAPIKeyHandler
+ ScheduledTest *admin.ScheduledTestHandler
+ Channel *admin.ChannelHandler
+ ChannelMonitor *admin.ChannelMonitorHandler
+ ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
+ Payment *admin.PaymentHandler
}
// Handlers contains all HTTP handlers
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 7c1a5d1b..6584eb70 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -35,36 +35,38 @@ func ProvideAdminHandlers(
scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler,
+ channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
paymentHandler *admin.PaymentHandler,
) *AdminHandlers {
return &AdminHandlers{
- Dashboard: dashboardHandler,
- User: userHandler,
- Group: groupHandler,
- Account: accountHandler,
- Announcement: announcementHandler,
- DataManagement: dataManagementHandler,
- Backup: backupHandler,
- OAuth: oauthHandler,
- OpenAIOAuth: openaiOAuthHandler,
- GeminiOAuth: geminiOAuthHandler,
- AntigravityOAuth: antigravityOAuthHandler,
- Proxy: proxyHandler,
- Redeem: redeemHandler,
- Promo: promoHandler,
- Setting: settingHandler,
- Ops: opsHandler,
- System: systemHandler,
- Subscription: subscriptionHandler,
- Usage: usageHandler,
- UserAttribute: userAttributeHandler,
- ErrorPassthrough: errorPassthroughHandler,
- TLSFingerprintProfile: tlsFingerprintProfileHandler,
- APIKey: apiKeyHandler,
- ScheduledTest: scheduledTestHandler,
- Channel: channelHandler,
- ChannelMonitor: channelMonitorHandler,
- Payment: paymentHandler,
+ Dashboard: dashboardHandler,
+ User: userHandler,
+ Group: groupHandler,
+ Account: accountHandler,
+ Announcement: announcementHandler,
+ DataManagement: dataManagementHandler,
+ Backup: backupHandler,
+ OAuth: oauthHandler,
+ OpenAIOAuth: openaiOAuthHandler,
+ GeminiOAuth: geminiOAuthHandler,
+ AntigravityOAuth: antigravityOAuthHandler,
+ Proxy: proxyHandler,
+ Redeem: redeemHandler,
+ Promo: promoHandler,
+ Setting: settingHandler,
+ Ops: opsHandler,
+ System: systemHandler,
+ Subscription: subscriptionHandler,
+ Usage: usageHandler,
+ UserAttribute: userAttributeHandler,
+ ErrorPassthrough: errorPassthroughHandler,
+ TLSFingerprintProfile: tlsFingerprintProfileHandler,
+ APIKey: apiKeyHandler,
+ ScheduledTest: scheduledTestHandler,
+ Channel: channelHandler,
+ ChannelMonitor: channelMonitorHandler,
+ ChannelMonitorTemplate: channelMonitorTemplateHandler,
+ Payment: paymentHandler,
}
}
@@ -162,6 +164,7 @@ var ProviderSet = wire.NewSet(
admin.NewScheduledTestHandler,
admin.NewChannelHandler,
admin.NewChannelMonitorHandler,
+ admin.NewChannelMonitorRequestTemplateHandler,
admin.NewPaymentHandler,
// AdminHandlers and Handlers constructors
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
index f4e2a0ec..67dccd6c 100644
--- a/backend/internal/repository/channel_monitor_repo.go
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -44,7 +44,15 @@ func (r *channelMonitorRepository) Create(ctx context.Context, m *service.Channe
SetGroupName(m.GroupName).
SetEnabled(m.Enabled).
SetIntervalSeconds(m.IntervalSeconds).
- SetCreatedBy(m.CreatedBy)
+ SetCreatedBy(m.CreatedBy).
+ SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
+ if m.TemplateID != nil {
+ builder = builder.SetTemplateID(*m.TemplateID)
+ }
+ if m.BodyOverride != nil {
+ builder = builder.SetBodyOverride(m.BodyOverride)
+ }
created, err := builder.Save(ctx)
if err != nil {
@@ -77,7 +85,19 @@ func (r *channelMonitorRepository) Update(ctx context.Context, m *service.Channe
SetExtraModels(emptySliceIfNil(m.ExtraModels)).
SetGroupName(m.GroupName).
SetEnabled(m.Enabled).
- SetIntervalSeconds(m.IntervalSeconds)
+ SetIntervalSeconds(m.IntervalSeconds).
+ SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
+ if m.TemplateID != nil {
+ updater = updater.SetTemplateID(*m.TemplateID)
+ } else {
+ updater = updater.ClearTemplateID()
+ }
+ if m.BodyOverride != nil {
+ updater = updater.SetBodyOverride(m.BodyOverride)
+ } else {
+ updater = updater.ClearBodyOverride()
+ }
updated, err := updater.Save(ctx)
if err != nil {
@@ -716,22 +736,51 @@ func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor {
if extras == nil {
extras = []string{}
}
- return &service.ChannelMonitor{
- ID: row.ID,
- Name: row.Name,
- Provider: string(row.Provider),
- Endpoint: row.Endpoint,
- APIKey: row.APIKeyEncrypted, // 仍为密文,service 层负责解密
- PrimaryModel: row.PrimaryModel,
- ExtraModels: extras,
- GroupName: row.GroupName,
- Enabled: row.Enabled,
- IntervalSeconds: row.IntervalSeconds,
- LastCheckedAt: row.LastCheckedAt,
- CreatedBy: row.CreatedBy,
- CreatedAt: row.CreatedAt,
- UpdatedAt: row.UpdatedAt,
+ headers := row.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ out := &service.ChannelMonitor{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Endpoint: row.Endpoint,
+ APIKey: row.APIKeyEncrypted, // 仍为密文,service 层负责解密
+ PrimaryModel: row.PrimaryModel,
+ ExtraModels: extras,
+ GroupName: row.GroupName,
+ Enabled: row.Enabled,
+ IntervalSeconds: row.IntervalSeconds,
+ LastCheckedAt: row.LastCheckedAt,
+ CreatedBy: row.CreatedBy,
+ CreatedAt: row.CreatedAt,
+ UpdatedAt: row.UpdatedAt,
+ ExtraHeaders: headers,
+ BodyOverrideMode: row.BodyOverrideMode,
+ BodyOverride: row.BodyOverride,
+ }
+ if row.TemplateID != nil {
+ id := *row.TemplateID
+ out.TemplateID = &id
+ }
+ return out
+}
+
+// emptyHeadersIfNilRepo behaves exactly like service.emptyHeadersIfNil;
+// the repository keeps its own copy to avoid an import cycle.
+func emptyHeadersIfNilRepo(h map[string]string) map[string]string {
+	if h == nil {
+		return map[string]string{}
+	}
+	return h
+}
+
+// defaultBodyModeRepo normalizes an empty mode string to "off" (repo-local copy, same cycle-avoidance note as above).
+func defaultBodyModeRepo(mode string) string {
+ if mode == "" {
+ return "off"
}
+ return mode
}
func emptySliceIfNil(in []string) []string {
diff --git a/backend/internal/repository/channel_monitor_template_repo.go b/backend/internal/repository/channel_monitor_template_repo.go
new file mode 100644
index 00000000..03f3692b
--- /dev/null
+++ b/backend/internal/repository/channel_monitor_template_repo.go
@@ -0,0 +1,168 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// channelMonitorRequestTemplateRepository implements service.ChannelMonitorRequestTemplateRepository.
+// Kept in a file separate from channelMonitorRepository for a clean separation of responsibilities.
+type channelMonitorRequestTemplateRepository struct {
+	client *dbent.Client
+	db     *sql.DB
+}
+
+// NewChannelMonitorRequestTemplateRepository creates a template repository instance.
+func NewChannelMonitorRequestTemplateRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRequestTemplateRepository {
+	return &channelMonitorRequestTemplateRepository{client: client, db: db}
+}
+
+// ---------- CRUD ----------
+
+func (r *channelMonitorRequestTemplateRepository) Create(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
+	client := clientFromContext(ctx, r.client) // presumably picks up a ctx-scoped (transactional) client — helper defined elsewhere
+	builder := client.ChannelMonitorRequestTemplate.Create().
+		SetName(t.Name).
+		SetProvider(channelmonitorrequesttemplate.Provider(t.Provider)).
+		SetDescription(t.Description).
+		SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
+		SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
+	if t.BodyOverride != nil {
+		builder = builder.SetBodyOverride(t.BodyOverride)
+	}
+
+	created, err := builder.Save(ctx)
+	if err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+	}
+	t.ID = created.ID // reflect DB-assigned identity and timestamps back onto the input
+	t.CreatedAt = created.CreatedAt
+	t.UpdatedAt = created.UpdatedAt
+	return nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitorRequestTemplate, error) {
+	row, err := r.client.ChannelMonitorRequestTemplate.Query().
+		Where(channelmonitorrequesttemplate.IDEQ(id)).
+		Only(ctx) // ent Only: errors unless exactly one row; not-found maps to the sentinel below
+	if err != nil {
+		return nil, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+	}
+	return entToServiceTemplate(row), nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) Update(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
+	client := clientFromContext(ctx, r.client)
+	updater := client.ChannelMonitorRequestTemplate.UpdateOneID(t.ID).
+		SetName(t.Name).
+		SetDescription(t.Description).
+		SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
+		SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
+	if t.BodyOverride != nil {
+		updater = updater.SetBodyOverride(t.BodyOverride)
+	} else {
+		updater = updater.ClearBodyOverride() // nil means "no override": clear any stored value
+	}
+	updated, err := updater.Save(ctx)
+	if err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+	}
+	t.UpdatedAt = updated.UpdatedAt
+	return nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) Delete(ctx context.Context, id int64) error {
+	client := clientFromContext(ctx, r.client)
+	if err := client.ChannelMonitorRequestTemplate.DeleteOneID(id).Exec(ctx); err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+	}
+	return nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, params service.ChannelMonitorRequestTemplateListParams) ([]*service.ChannelMonitorRequestTemplate, error) {
+	q := r.client.ChannelMonitorRequestTemplate.Query()
+	if params.Provider != "" {
+		q = q.Where(channelmonitorrequesttemplate.ProviderEQ(channelmonitorrequesttemplate.Provider(params.Provider)))
+	}
+	rows, err := q.
+		Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldName)).
+		All(ctx) // deterministic provider+name ordering for stable listings
+	if err != nil {
+		return nil, fmt.Errorf("list monitor templates: %w", err)
+	}
+	out := make([]*service.ChannelMonitorRequestTemplate, 0, len(rows))
+	for _, row := range rows {
+		out = append(out, entToServiceTemplate(row))
+	}
+	return out, nil
+}
+
+// ApplyToMonitors bulk-overwrites the template's current config onto every monitor with template_id = id.
+//
+// A single UPDATE covers extra_headers / body_override_mode / body_override. Going through ent's
+// UpdateMany keeps ent hooks in play; raw SQL would also work, but ent's jsonb marshaling is simpler.
+func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64) (int64, error) {
+	client := clientFromContext(ctx, r.client)
+	tpl, err := client.ChannelMonitorRequestTemplate.Query().
+		Where(channelmonitorrequesttemplate.IDEQ(id)).
+		Only(ctx)
+	if err != nil {
+		return 0, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+	}
+
+	updater := client.ChannelMonitor.Update().
+		Where(channelmonitor.TemplateIDEQ(id)).
+		SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)).
+		SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode))
+	if tpl.BodyOverride != nil {
+		updater = updater.SetBodyOverride(tpl.BodyOverride)
+	} else {
+		updater = updater.ClearBodyOverride()
+	}
+
+	affected, err := updater.Save(ctx)
+	if err != nil {
+		return 0, fmt.Errorf("apply template to monitors: %w", err)
+	}
+	return int64(affected), nil
+}
+
+// CountAssociatedMonitors counts the monitors linked to the template (used by the UI to show "N configs").
+func (r *channelMonitorRequestTemplateRepository) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) {
+	count, err := r.client.ChannelMonitor.Query().
+		Where(channelmonitor.TemplateIDEQ(id)).
+		Count(ctx)
+	if err != nil {
+		return 0, fmt.Errorf("count monitors for template %d: %w", id, err)
+	}
+	return int64(count), nil
+}
+
+// ---------- helpers ----------
+
+func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.ChannelMonitorRequestTemplate {
+	if row == nil {
+		return nil
+	}
+	headers := row.ExtraHeaders
+	if headers == nil {
+		headers = map[string]string{} // normalize nil to an empty map for callers / JSON output
+	}
+	return &service.ChannelMonitorRequestTemplate{
+		ID:               row.ID,
+		Name:             row.Name,
+		Provider:         string(row.Provider),
+		Description:      row.Description,
+		ExtraHeaders:     headers,
+		BodyOverrideMode: row.BodyOverrideMode,
+		BodyOverride:     row.BodyOverride,
+		CreatedAt:        row.CreatedAt,
+		UpdatedAt:        row.UpdatedAt,
+	}
+}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 7427cd04..b1d5e36a 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -90,6 +90,7 @@ var ProviderSet = wire.NewSet(
NewTLSFingerprintProfileRepository,
NewChannelRepository,
NewChannelMonitorRepository,
+ NewChannelMonitorRequestTemplateRepository,
// Cache implementations
NewGatewayCache,
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 0381dc57..13cecd59 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -579,4 +579,14 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
monitors.POST("/:id/run", h.Admin.ChannelMonitor.Run)
monitors.GET("/:id/history", h.Admin.ChannelMonitor.History)
}
+
+ templates := admin.Group("/channel-monitor-templates")
+ {
+ templates.GET("", h.Admin.ChannelMonitorTemplate.List)
+ templates.POST("", h.Admin.ChannelMonitorTemplate.Create)
+ templates.GET("/:id", h.Admin.ChannelMonitorTemplate.Get)
+ templates.PUT("/:id", h.Admin.ChannelMonitorTemplate.Update)
+ templates.DELETE("/:id", h.Admin.ChannelMonitorTemplate.Delete)
+ templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
+ }
}
diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go
index e03c2e3a..33570629 100644
--- a/backend/internal/service/channel_monitor_checker.go
+++ b/backend/internal/service/channel_monitor_checker.go
@@ -37,9 +37,23 @@ func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
return &http.Client{Timeout: timeout, Transport: tr}
}
+// CheckOptions carries the per-check customization inputs.
+// Every field is optional (the zero value is equivalent to "use the default behavior").
+type CheckOptions struct {
+	// ExtraHeaders are user-defined HTTP headers (merged over the adapter defaults; user values win).
+	ExtraHeaders map[string]string
+	// BodyOverrideMode: off | merge | replace
+	BodyOverrideMode string
+	// BodyOverride is shallow-merged in merge mode (deny-listed keys are silently dropped)
+	// and used verbatim as the complete body in replace mode.
+	BodyOverride map[string]any
+}
+
// runCheckForModel 对单个 (provider, model) 做一次完整检测。
// 不返回 error:所有失败都包装进 CheckResult.Status=error/failed。
-func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string) *CheckResult {
+//
+// opts carries the customization from the template / monitor snapshot; nil is equivalent to "off + no extra headers".
+func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string, opts *CheckOptions) *CheckResult {
res := &CheckResult{
Model: model,
Status: MonitorStatusError,
@@ -47,9 +61,10 @@ func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model str
}
challenge := generateChallenge()
+ mode := bodyOverrideMode(opts)
start := time.Now()
- respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt)
+ respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt, opts)
latency := time.Since(start)
latencyMs := int(latency / time.Millisecond)
res.LatencyMs = &latencyMs
@@ -68,22 +83,47 @@ func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model str
return res
}
+	// Replace mode: skip challenge validation (the user body is static, so the challenge cannot be embedded).
+	// Instead, treat "HTTP 2xx + non-empty response text (extracted via adapter.textPath)" as operational.
+	// Empty response text downgrades to failed (treated as an upstream 200 with no real content).
+ if mode == MonitorBodyOverrideModeReplace {
+ if strings.TrimSpace(respText) == "" {
+ res.Status = MonitorStatusFailed
+ res.Message = truncateMessage("replace-mode: upstream returned 2xx with empty text")
+ return res
+ }
+ return finalizeOperationalOrDegraded(res, latency, latencyMs)
+ }
+
if !validateChallenge(respText, challenge.Expected) {
res.Status = MonitorStatusFailed
res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("challenge mismatch (expected %s, got %q)", challenge.Expected, respText)))
return res
}
+ return finalizeOperationalOrDegraded(res, latency, latencyMs)
+}
+
+// finalizeOperationalOrDegraded makes the final operational/degraded decision.
+// Split out so runCheckForModel stays under 30 lines.
+func finalizeOperationalOrDegraded(res *CheckResult, latency time.Duration, latencyMs int) *CheckResult {
if latency >= monitorDegradedThreshold {
res.Status = MonitorStatusDegraded
res.Message = truncateMessage(fmt.Sprintf("slow response: %dms", latencyMs))
return res
}
-
res.Status = MonitorStatusOperational
return res
}
+// bodyOverrideMode normalizes opts.BodyOverrideMode; a nil opts or empty string is treated as off.
+func bodyOverrideMode(opts *CheckOptions) string {
+	if opts == nil || opts.BodyOverrideMode == "" {
+		return MonitorBodyOverrideModeOff
+	}
+	return opts.BodyOverrideMode
+}
+
// pingEndpointOrigin 对 endpoint 的 origin (scheme://host) 发起 HEAD 请求,返回耗时。
// 失败时返回 nil(不影响主状态判定)。
func pingEndpointOrigin(ctx context.Context, endpoint string) *int {
@@ -183,29 +223,109 @@ func isSupportedProvider(p string) bool {
}
// callProvider 通过 providerAdapters 分发到具体实现。
+// opts carries the user's custom headers / body override (may be nil).
//
// 返回值:
// - extractedText: 按 textPath 抽出的成功文本,仅在 status 2xx 时有意义;非 2xx 时通常为空串
// - rawBody: 完整响应体的字符串形式(已被 monitorResponseMaxBytes 截断),用于错误路径保留上游真实回包
// - status: HTTP 状态码
// - err: 网络 / 序列化错误
-func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string) (extractedText, rawBody string, status int, err error) {
+func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) {
adapter, ok := providerAdapters[provider]
if !ok {
return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
}
- body, err := adapter.buildBody(model, prompt)
+ body, err := buildRequestBody(adapter, provider, model, prompt, opts)
if err != nil {
- return "", "", 0, fmt.Errorf("marshal body: %w", err)
+ return "", "", 0, err
}
+ headers := mergeHeaders(adapter.buildHeaders(apiKey), opts)
full := joinURL(endpoint, adapter.buildPath(model))
- respBytes, status, err := postRawJSON(ctx, full, body, adapter.buildHeaders(apiKey))
+ respBytes, status, err := postRawJSON(ctx, full, body, headers)
if err != nil {
return "", "", status, err
}
return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil
}
+// mergeHeaders overlays the user's custom headers onto the adapter's default headers.
+// User values override defaults; deny-listed keys (hop-by-hop / managed by http.Client) are silently dropped.
+func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string {
+	if opts == nil || len(opts.ExtraHeaders) == 0 {
+		return base
+	}
+	out := make(map[string]string, len(base)+len(opts.ExtraHeaders))
+	for k, v := range base {
+		out[k] = v
+	}
+	for k, v := range opts.ExtraHeaders {
+		if IsForbiddenHeaderName(k) {
+			continue
+		}
+		out[k] = v
+	}
+	return out
+}
+
+// buildRequestBody constructs the request body according to body_override_mode.
+//
+// - off: the adapter's default body
+// - merge: shallow-merge BodyOverride onto the adapter default; keys hitting
+//   bodyMergeKeyDenyList[provider] are silently dropped to protect challenge / model routing
+// - replace: marshal BodyOverride verbatim as the complete body
+//
+// In every mode the returned []byte is already valid JSON and can go straight to postRawJSON.
+func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) {
+	mode := bodyOverrideMode(opts)
+
+	if mode == MonitorBodyOverrideModeReplace {
+		if opts == nil || len(opts.BodyOverride) == 0 {
+			return nil, fmt.Errorf("replace mode: body_override is empty")
+		}
+		body, err := json.Marshal(opts.BodyOverride)
+		if err != nil {
+			return nil, fmt.Errorf("marshal body_override (replace): %w", err)
+		}
+		return body, nil
+	}
+
+	defaultBody, err := adapter.buildBody(model, prompt)
+	if err != nil {
+		return nil, fmt.Errorf("marshal default body: %w", err)
+	}
+	if mode != MonitorBodyOverrideModeMerge || opts == nil || len(opts.BodyOverride) == 0 {
+		return defaultBody, nil
+	}
+
+	var defaultMap map[string]any
+	if err := json.Unmarshal(defaultBody, &defaultMap); err != nil {
+		return nil, fmt.Errorf("unmarshal default body for merge: %w", err)
+	}
+	deny := bodyMergeKeyDenyList[provider]
+	for k, v := range opts.BodyOverride {
+		if deny[k] {
+			continue // deny-listed key: silently dropped
+		}
+		defaultMap[k] = v
+	}
+	merged, err := json.Marshal(defaultMap)
+	if err != nil {
+		return nil, fmt.Errorf("marshal merged body: %w", err)
+	}
+	return merged, nil
+}
+
+// bodyMergeKeyDenyList lists provider-specific keys users may not override in merge mode,
+// modeled on check-cx's EXCLUDED_METADATA_KEYS: it keeps challenge / model routing intact.
+// Users who really need these fields should use replace mode (known to skip challenge validation).
+//
+//nolint:gochecknoglobals // static lookup table, immutable after initialization.
+var bodyMergeKeyDenyList = map[string]map[string]bool{
+	MonitorProviderOpenAI:    {"model": true, "messages": true, "stream": true},
+	MonitorProviderAnthropic: {"model": true, "messages": true},
+	MonitorProviderGemini:    {"contents": true},
+}
+
// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
// adapter 自行 marshal 是为了精确控制字段顺序与类型,所以这里直接收 []byte 而不是 any。
func postRawJSON(ctx context.Context, fullURL string, payload []byte, headers map[string]string) ([]byte, int, error) {
diff --git a/backend/internal/service/channel_monitor_checker_body_test.go b/backend/internal/service/channel_monitor_checker_body_test.go
new file mode 100644
index 00000000..323cf8b7
--- /dev/null
+++ b/backend/internal/service/channel_monitor_checker_body_test.go
@@ -0,0 +1,173 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+)
+
+// swapMonitorHTTPClient temporarily replaces monitorHTTPClient with a plain client without
+// SSRF validation so httptest servers (127.0.0.1) are reachable; the original is restored after the test.
+func swapMonitorHTTPClient(t *testing.T) {
+	t.Helper()
+	orig := monitorHTTPClient
+	monitorHTTPClient = &http.Client{Timeout: 5 * time.Second}
+	t.Cleanup(func() { monitorHTTPClient = orig })
+}
+
+// captureHandler records each incoming request's body and headers so tests can assert on them.
+type captureHandler struct {
+	lastBody    map[string]any
+	lastHeaders http.Header
+	respondText string // written into the Anthropic content[0].text field (used for validation)
+	status      int
+}
+
+func (h *captureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	h.lastHeaders = r.Header.Clone()
+	defer func() { _ = r.Body.Close() }()
+	var parsed map[string]any
+	_ = json.NewDecoder(r.Body).Decode(&parsed)
+	h.lastBody = parsed
+
+	if h.status == 0 {
+		h.status = 200
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(h.status)
+	// Build an Anthropic-shaped response: content[0].text = h.respondText.
+	_ = json.NewEncoder(w).Encode(map[string]any{
+		"content": []map[string]any{
+			{"type": "text", "text": h.respondText},
+		},
+	})
+}
+
+func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
+	t.Helper()
+	swapMonitorHTTPClient(t) // allow 127.0.0.1 despite the SSRF-safe default client
+	srv := httptest.NewServer(handler)
+	t.Cleanup(srv.Close)
+	return srv.URL // used as the monitor endpoint
+}
+
+func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
+	h := &captureHandler{respondText: "the answer is 42"}
+	endpoint := setupFakeAnthropic(t, h)
+
+	// Run one off-mode check (opts=nil) to confirm the default body behavior is unchanged.
+	_ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", nil)
+
+	if h.lastBody["model"] != "claude-x" {
+		t.Errorf("default body should contain model=claude-x, got %v", h.lastBody["model"])
+	}
+	if _, ok := h.lastBody["messages"]; !ok {
+		t.Error("default body should contain messages")
+	}
+	if h.lastHeaders.Get("x-api-key") != "sk-fake" {
+		t.Errorf("expected adapter's x-api-key header, got %q", h.lastHeaders.Get("x-api-key"))
+	}
+}
+
+func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
+	h := &captureHandler{respondText: "the answer is 42"}
+	endpoint := setupFakeAnthropic(t, h)
+
+	opts := &CheckOptions{
+		BodyOverrideMode: MonitorBodyOverrideModeMerge,
+		BodyOverride: map[string]any{
+			"system":     "You are Claude Code...",
+			"max_tokens": float64(999), // should override the default 50
+			"model":      "hacked-model", // should be blocked by the deny list, keeping the original model
+			"messages":   []any{}, // same as above, blocked
+		},
+		ExtraHeaders: map[string]string{
+			"User-Agent":     "claude-cli/1.0",
+			"Content-Length": "999", // deny-listed
+			"x-custom":       "ok",
+		},
+	}
+	_ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
+
+	if h.lastBody["system"] != "You are Claude Code..." {
+		t.Errorf("merge mode should inject system, got %v", h.lastBody["system"])
+	}
+	// The max_tokens override takes effect.
+	if mt, ok := h.lastBody["max_tokens"].(float64); !ok || mt != 999 {
+		t.Errorf("merge mode should override max_tokens to 999, got %v", h.lastBody["max_tokens"])
+	}
+	// model is deny-listed — the default value must be kept.
+	if h.lastBody["model"] != "claude-x" {
+		t.Errorf("model should be protected by deny list, got %v", h.lastBody["model"])
+	}
+	// messages is deny-listed — the default (non-empty) value must be kept.
+	msgs, _ := h.lastBody["messages"].([]any)
+	if len(msgs) == 0 {
+		t.Error("messages should be protected by deny list (kept default, non-empty)")
+	}
+	// Header merging.
+	if h.lastHeaders.Get("User-Agent") != "claude-cli/1.0" {
+		t.Errorf("extra User-Agent should override, got %q", h.lastHeaders.Get("User-Agent"))
+	}
+	if h.lastHeaders.Get("x-custom") != "ok" {
+		t.Errorf("extra custom header should be present, got %q", h.lastHeaders.Get("x-custom"))
+	}
+	// Content-Length is deny-listed: net/http recomputes it automatically, so the user's "999" must not decide it.
+	// We cannot assert the drop directly (http.Client always fills it in); asserting a successful request suffices.
+}
+
+func TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped(t *testing.T) {
+	// In replace mode the body is entirely user-defined: the challenge math problem never appears in the request,
+	// so the upstream cannot answer it — yet 2xx + non-empty response text still counts as operational.
+	h := &captureHandler{respondText: "any non-empty text"}
+	endpoint := setupFakeAnthropic(t, h)
+
+	userBody := map[string]any{
+		"model":      "user-forced-model",
+		"messages":   []any{map[string]any{"role": "user", "content": "hi"}},
+		"max_tokens": float64(10),
+		"system":     "You are someone else",
+	}
+	opts := &CheckOptions{
+		BodyOverrideMode: MonitorBodyOverrideModeReplace,
+		BodyOverride:     userBody,
+	}
+	res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
+
+	// The request body must be exactly what the user supplied.
+	if h.lastBody["model"] != "user-forced-model" {
+		t.Errorf("replace mode should use user's model, got %v", h.lastBody["model"])
+	}
+	if h.lastBody["system"] != "You are someone else" {
+		t.Errorf("replace mode should use user's system, got %v", h.lastBody["system"])
+	}
+	// The challenge never matched, but replace mode skips challenge validation and the response is non-empty → operational.
+	if res.Status != MonitorStatusOperational {
+		t.Errorf("replace mode with 2xx + non-empty text should be operational, got status=%s message=%q",
+			res.Status, res.Message)
+	}
+}
+
+func TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed(t *testing.T) {
+	h := &captureHandler{respondText: ""} // upstream returns 200 but content[0].text is empty
+	endpoint := setupFakeAnthropic(t, h)
+
+	opts := &CheckOptions{
+		BodyOverrideMode: MonitorBodyOverrideModeReplace,
+		BodyOverride:     map[string]any{"model": "x", "messages": []any{}},
+	}
+	res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
+
+	if res.Status != MonitorStatusFailed {
+		t.Errorf("replace mode with empty text should be failed, got status=%s", res.Status)
+	}
+	if !strings.Contains(res.Message, "replace-mode") {
+		t.Errorf("failure message should hint replace-mode, got %q", res.Message)
+	}
+}
diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go
index 144c66a0..ec1107a3 100644
--- a/backend/internal/service/channel_monitor_service.go
+++ b/backend/internal/service/channel_monitor_service.go
@@ -104,21 +104,31 @@ func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCrea
if err := validateCreateParams(p); err != nil {
return nil, err
}
+ if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
+ return nil, err
+ }
+ if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
+ return nil, err
+ }
encrypted, err := s.encryptor.Encrypt(p.APIKey)
if err != nil {
return nil, fmt.Errorf("encrypt api key: %w", err)
}
m := &ChannelMonitor{
- Name: strings.TrimSpace(p.Name),
- Provider: p.Provider,
- Endpoint: normalizeEndpoint(p.Endpoint),
- APIKey: encrypted, // 注意:传入 repository 时该字段为密文
- PrimaryModel: strings.TrimSpace(p.PrimaryModel),
- ExtraModels: normalizeModels(p.ExtraModels),
- GroupName: strings.TrimSpace(p.GroupName),
- Enabled: p.Enabled,
- IntervalSeconds: p.IntervalSeconds,
- CreatedBy: p.CreatedBy,
+ Name: strings.TrimSpace(p.Name),
+ Provider: p.Provider,
+ Endpoint: normalizeEndpoint(p.Endpoint),
+ APIKey: encrypted, // 注意:传入 repository 时该字段为密文
+ PrimaryModel: strings.TrimSpace(p.PrimaryModel),
+ ExtraModels: normalizeModels(p.ExtraModels),
+ GroupName: strings.TrimSpace(p.GroupName),
+ Enabled: p.Enabled,
+ IntervalSeconds: p.IntervalSeconds,
+ CreatedBy: p.CreatedBy,
+ TemplateID: p.TemplateID,
+ ExtraHeaders: emptyHeadersIfNil(p.ExtraHeaders),
+ BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
+ BodyOverride: p.BodyOverride,
}
if err := s.repo.Create(ctx, m); err != nil {
return nil, fmt.Errorf("create channel monitor: %w", err)
@@ -272,12 +282,19 @@ func (s *ChannelMonitorService) runChecksConcurrent(ctx context.Context, m *Chan
// ping 共享一次,所有模型记录同一个 ping 延迟。
pingMs := pingEndpointOrigin(ctx, m.Endpoint)
+	// All models share a single CheckOptions built from the monitor's snapshot fields.
+ opts := &CheckOptions{
+ ExtraHeaders: m.ExtraHeaders,
+ BodyOverrideMode: m.BodyOverrideMode,
+ BodyOverride: m.BodyOverride,
+ }
+
var eg errgroup.Group
var mu sync.Mutex
for i, model := range models {
i, model := i, model
eg.Go(func() error {
- r := runCheckForModel(ctx, m.Provider, m.Endpoint, m.APIKey, model)
+ r := runCheckForModel(ctx, m.Provider, m.Endpoint, m.APIKey, model, opts)
r.PingLatencyMs = pingMs
mu.Lock()
results[i] = r
@@ -476,5 +493,38 @@ func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams)
}
existing.IntervalSeconds = *p.IntervalSeconds
}
+ return applyMonitorAdvancedUpdate(existing, p)
+}
+
+// applyMonitorAdvancedUpdate handles the custom-request-snapshot fields; split out of applyMonitorUpdate to keep it short.
+func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
+ if p.ClearTemplate {
+ existing.TemplateID = nil
+ } else if p.TemplateID != nil {
+ id := *p.TemplateID
+ existing.TemplateID = &id
+ }
+ if p.ExtraHeaders != nil {
+ if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
+ return err
+ }
+ existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
+ }
+	// BodyOverrideMode / BodyOverride are validated jointly, matching the template validation.
+ newMode := existing.BodyOverrideMode
+ newBody := existing.BodyOverride
+ if p.BodyOverrideMode != nil {
+ newMode = *p.BodyOverrideMode
+ }
+ if p.BodyOverride != nil {
+ newBody = *p.BodyOverride
+ }
+ if p.BodyOverrideMode != nil || p.BodyOverride != nil {
+ if err := validateBodyModeParams(newMode, newBody); err != nil {
+ return err
+ }
+ existing.BodyOverrideMode = defaultBodyMode(newMode)
+ existing.BodyOverride = newBody
+ }
return nil
}
diff --git a/backend/internal/service/channel_monitor_template_service.go b/backend/internal/service/channel_monitor_template_service.go
new file mode 100644
index 00000000..98fc930b
--- /dev/null
+++ b/backend/internal/service/channel_monitor_template_service.go
@@ -0,0 +1,225 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strings"
+)
+
// ChannelMonitorRequestTemplateRepository is the data-access contract for
// channel-monitor request templates.
type ChannelMonitorRequestTemplateRepository interface {
	Create(ctx context.Context, t *ChannelMonitorRequestTemplate) error
	GetByID(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error)
	Update(ctx context.Context, t *ChannelMonitorRequestTemplate) error
	Delete(ctx context.Context, id int64) error
	List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error)
	// ApplyToMonitors bulk-overwrites extra_headers / body_override_mode /
	// body_override on every monitor with template_id = id, using the
	// template's current values. Returns the number of monitors updated.
	ApplyToMonitors(ctx context.Context, id int64) (int64, error)
	// CountAssociatedMonitors counts monitors with template_id = id (used by
	// the UI to display "applies to N configurations").
	CountAssociatedMonitors(ctx context.Context, id int64) (int64, error)
}
+
// ChannelMonitorRequestTemplateService manages channel-monitor request
// templates (CRUD plus the bulk "apply to associated monitors" action).
type ChannelMonitorRequestTemplateService struct {
	repo ChannelMonitorRequestTemplateRepository
}

// NewChannelMonitorRequestTemplateService builds a template service backed by
// the given repository.
func NewChannelMonitorRequestTemplateService(repo ChannelMonitorRequestTemplateRepository) *ChannelMonitorRequestTemplateService {
	return &ChannelMonitorRequestTemplateService{repo: repo}
}
+
+// ---------- CRUD ----------
+
+// List 按 provider 过滤(空串 = 全部),不分页(模板量级小)。
+func (s *ChannelMonitorRequestTemplateService) List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error) {
+ if params.Provider != "" {
+ if err := validateProvider(params.Provider); err != nil {
+ return nil, err
+ }
+ }
+ return s.repo.List(ctx, params)
+}
+
// Get returns a single template by ID.
// Repo errors pass through unwrapped; presumably the repository maps "no rows"
// to ErrChannelMonitorTemplateNotFound — TODO confirm in the repo layer.
func (s *ChannelMonitorRequestTemplateService) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) {
	return s.repo.GetByID(ctx, id)
}
+
+// Create 创建模板(会校验 headers 黑名单和 body 模式匹配)。
+func (s *ChannelMonitorRequestTemplateService) Create(ctx context.Context, p ChannelMonitorRequestTemplateCreateParams) (*ChannelMonitorRequestTemplate, error) {
+ if err := validateTemplateCreateParams(p); err != nil {
+ return nil, err
+ }
+ t := &ChannelMonitorRequestTemplate{
+ Name: strings.TrimSpace(p.Name),
+ Provider: p.Provider,
+ Description: strings.TrimSpace(p.Description),
+ ExtraHeaders: emptyHeadersIfNil(p.ExtraHeaders),
+ BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
+ BodyOverride: p.BodyOverride,
+ }
+ if err := s.repo.Create(ctx, t); err != nil {
+ return nil, fmt.Errorf("create template: %w", err)
+ }
+ return t, nil
+}
+
+// Update 更新模板(provider 不可改)。
+func (s *ChannelMonitorRequestTemplateService) Update(ctx context.Context, id int64, p ChannelMonitorRequestTemplateUpdateParams) (*ChannelMonitorRequestTemplate, error) {
+ existing, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ if err := applyTemplateUpdate(existing, p); err != nil {
+ return nil, err
+ }
+ if err := s.repo.Update(ctx, existing); err != nil {
+ return nil, fmt.Errorf("update template: %w", err)
+ }
+ return existing, nil
+}
+
+// Delete 删除模板。关联监控的 template_id 会被 SET NULL,监控保留快照继续跑。
+func (s *ChannelMonitorRequestTemplateService) Delete(ctx context.Context, id int64) error {
+ if err := s.repo.Delete(ctx, id); err != nil {
+ return fmt.Errorf("delete template: %w", err)
+ }
+ return nil
+}
+
+// ApplyToMonitors 把模板当前配置一键应用到所有关联监控。
+// 返回被影响的监控数。
+func (s *ChannelMonitorRequestTemplateService) ApplyToMonitors(ctx context.Context, id int64) (int64, error) {
+ if _, err := s.repo.GetByID(ctx, id); err != nil {
+ return 0, err
+ }
+ affected, err := s.repo.ApplyToMonitors(ctx, id)
+ if err != nil {
+ return 0, fmt.Errorf("apply template to monitors: %w", err)
+ }
+ return affected, nil
+}
+
// CountAssociatedMonitors returns how many monitors reference this template
// (template_id = id). Used by the UI to render "applies to N configurations".
func (s *ChannelMonitorRequestTemplateService) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) {
	return s.repo.CountAssociatedMonitors(ctx, id)
}
+
+// ---------- 校验 & 工具 ----------
+
+// validateTemplateCreateParams 聚合 create 入参校验,避免函数超过 30 行。
+func validateTemplateCreateParams(p ChannelMonitorRequestTemplateCreateParams) error {
+ if strings.TrimSpace(p.Name) == "" {
+ return ErrChannelMonitorTemplateMissingName
+ }
+ if err := validateProvider(p.Provider); err != nil {
+ return ErrChannelMonitorTemplateInvalidProvider
+ }
+ if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
+ return err
+ }
+ if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
+ return err
+ }
+ return nil
+}
+
+// applyTemplateUpdate 把 update params 中非 nil 字段应用到 existing 上。
+func applyTemplateUpdate(existing *ChannelMonitorRequestTemplate, p ChannelMonitorRequestTemplateUpdateParams) error {
+ if p.Name != nil {
+ name := strings.TrimSpace(*p.Name)
+ if name == "" {
+ return ErrChannelMonitorTemplateMissingName
+ }
+ existing.Name = name
+ }
+ if p.Description != nil {
+ existing.Description = strings.TrimSpace(*p.Description)
+ }
+ if p.ExtraHeaders != nil {
+ if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
+ return err
+ }
+ existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
+ }
+ // BodyOverrideMode / BodyOverride 联合校验:任一变化都用「更新后的值」做校验。
+ newMode := existing.BodyOverrideMode
+ newBody := existing.BodyOverride
+ if p.BodyOverrideMode != nil {
+ newMode = *p.BodyOverrideMode
+ }
+ if p.BodyOverride != nil {
+ newBody = *p.BodyOverride
+ }
+ if err := validateBodyModeParams(newMode, newBody); err != nil {
+ return err
+ }
+ existing.BodyOverrideMode = defaultBodyMode(newMode)
+ existing.BodyOverride = newBody
+ return nil
+}
+
+// validateBodyModeParams 校验 body_override_mode 合法,且 merge/replace 模式下 body_override 非空。
+func validateBodyModeParams(mode string, body map[string]any) error {
+ switch mode {
+ case "", MonitorBodyOverrideModeOff:
+ return nil
+ case MonitorBodyOverrideModeMerge, MonitorBodyOverrideModeReplace:
+ if len(body) == 0 {
+ return ErrChannelMonitorTemplateBodyRequired
+ }
+ return nil
+ default:
+ return ErrChannelMonitorTemplateInvalidBodyMode
+ }
+}
+
// headerNameRegex matches a legal header name: an RFC 7230 "token" (visible
// ASCII letters/digits plus a fixed set of symbols). Anchored on both ends so
// the entire name must be a token — no whitespace, colons, or control chars.
var headerNameRegex = regexp.MustCompile(`^[A-Za-z0-9!#$%&'*+\-.^_` + "`" + `|~]+$`)
+
// forbiddenHeaderNames lists headers users may never override: hop-by-hop
// headers (RFC 7230 §6.1) plus fields Go's http.Client computes itself.
// Letting users set these breaks the client (duplicate Content-Length, broken
// connection reuse, smuggled Transfer-Encoding, ...). Keys are lowercase;
// IsForbiddenHeaderName normalizes before the lookup.
var forbiddenHeaderNames = map[string]bool{
	// computed / managed by net/http
	"host":             true,
	"content-length":   true,
	"content-encoding": true,
	// hop-by-hop (RFC 7230 §6.1) — the original list was missing most of
	// these; all of them confuse connection handling if user-supplied
	"transfer-encoding": true,
	"connection":        true,
	"keep-alive":        true,
	"te":                true,
	"trailer":           true,
	"upgrade":           true,
	"proxy-connection":  true,
}
+
+// IsForbiddenHeaderName 对外暴露,checker 运行时也会再过滤一次做兜底。
+func IsForbiddenHeaderName(name string) bool {
+ return forbiddenHeaderNames[strings.ToLower(strings.TrimSpace(name))]
+}
+
+// validateExtraHeaders 校验 header 名字格式 + 黑名单。保存时就拒绝非法 header,早失败。
+func validateExtraHeaders(h map[string]string) error {
+ for k := range h {
+ if !headerNameRegex.MatchString(k) {
+ return ErrChannelMonitorTemplateHeaderInvalidName
+ }
+ if IsForbiddenHeaderName(k) {
+ return ErrChannelMonitorTemplateHeaderForbidden
+ }
+ }
+ return nil
+}
+
// emptyHeadersIfNil normalizes a nil map to an empty one; the repo layer
// writes this value as JSONB and requires a non-nil map.
func emptyHeadersIfNil(h map[string]string) map[string]string {
	if h != nil {
		return h
	}
	return map[string]string{}
}
+
+// defaultBodyMode 空串归一为 off。
+func defaultBodyMode(mode string) string {
+ if mode == "" {
+ return MonitorBodyOverrideModeOff
+ }
+ return mode
+}
diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go
new file mode 100644
index 00000000..a6e2bb59
--- /dev/null
+++ b/backend/internal/service/channel_monitor_template_types.go
@@ -0,0 +1,74 @@
+package service
+
import (
	"time"

	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
+
// ChannelMonitorRequestTemplate is the service-layer model of a request
// template: a reusable bundle of extra headers plus an optional body override.
// When a monitor "applies" a template, these fields are copied onto the
// monitor as a snapshot; later template edits do not propagate automatically.
type ChannelMonitorRequestTemplate struct {
	ID               int64
	Name             string            // unique per provider (DB unique index on provider+name)
	Provider         string            // openai / anthropic / gemini
	Description      string
	ExtraHeaders     map[string]string // merged over adapter default headers; user values win
	BodyOverrideMode string            // off / merge / replace (MonitorBodyOverrideMode*)
	BodyOverride     map[string]any    // only consulted when mode != off
	CreatedAt        time.Time
	UpdatedAt        time.Time
}
+
// ChannelMonitorRequestTemplateListParams filters List.
type ChannelMonitorRequestTemplateListParams struct {
	Provider string // empty = all providers; non-empty filters by provider
}

// ChannelMonitorRequestTemplateCreateParams carries Create input.
type ChannelMonitorRequestTemplateCreateParams struct {
	Name             string            // required; whitespace-trimmed before save
	Provider         string            // required: openai / anthropic / gemini
	Description      string            // optional; whitespace-trimmed
	ExtraHeaders     map[string]string // validated against name regex + blacklist
	BodyOverrideMode string            // "" is normalized to "off"
	BodyOverride     map[string]any    // required (non-empty) for merge/replace
}

// ChannelMonitorRequestTemplateUpdateParams carries Update input; a nil field
// means "leave unchanged". Provider is intentionally absent — changing it
// would desynchronize the body-blacklist semantics of monitors already
// associated with the template.
type ChannelMonitorRequestTemplateUpdateParams struct {
	Name             *string
	Description      *string
	ExtraHeaders     *map[string]string
	BodyOverrideMode *string
	BodyOverride     *map[string]any
}
+
// Template-related errors (naming follows the existing ErrChannelMonitor* style).
var (
	ErrChannelMonitorTemplateNotFound = infraerrors.NotFound(
		"CHANNEL_MONITOR_TEMPLATE_NOT_FOUND", "channel monitor request template not found",
	)
	ErrChannelMonitorTemplateInvalidProvider = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_INVALID_PROVIDER", "template provider must be one of openai/anthropic/gemini",
	)
	ErrChannelMonitorTemplateMissingName = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_MISSING_NAME", "template name is required",
	)
	ErrChannelMonitorTemplateInvalidBodyMode = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_INVALID_BODY_MODE", "body_override_mode must be one of off/merge/replace",
	)
	ErrChannelMonitorTemplateBodyRequired = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_BODY_REQUIRED", "body_override is required when body_override_mode is merge or replace",
	)
	ErrChannelMonitorTemplateHeaderForbidden = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_HEADER_FORBIDDEN", "header name is forbidden (hop-by-hop or computed by HTTP client)",
	)
	ErrChannelMonitorTemplateHeaderInvalidName = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_HEADER_INVALID_NAME", "header name contains invalid characters",
	)
	// ErrChannelMonitorTemplateProviderMismatch is not referenced in the
	// template service itself — presumably raised by the monitor service when
	// associating a template of a different provider; TODO confirm at caller.
	ErrChannelMonitorTemplateProviderMismatch = infraerrors.BadRequest(
		"CHANNEL_MONITOR_TEMPLATE_PROVIDER_MISMATCH", "monitor provider does not match template provider",
	)
)
diff --git a/backend/internal/service/channel_monitor_types.go b/backend/internal/service/channel_monitor_types.go
index 739c82fb..b797a89b 100644
--- a/backend/internal/service/channel_monitor_types.go
+++ b/backend/internal/service/channel_monitor_types.go
@@ -2,6 +2,19 @@ package service
import "time"
// MonitorBodyOverrideMode values control how a custom request body is applied
// during a check.
//
//   - off     use the adapter's default body (BodyOverride is ignored)
//   - merge   shallow-merge BodyOverride over the default body; user fields
//     win, but protected keys (model/messages/contents, per the checker's
//     blacklist) are silently dropped
//   - replace use BodyOverride as the complete body; challenge validation is
//     skipped and "HTTP 2xx + non-empty response" counts as operational
//     (the user is responsible for constructing a valid body)
const (
	MonitorBodyOverrideModeOff     = "off"
	MonitorBodyOverrideModeMerge   = "merge"
	MonitorBodyOverrideModeReplace = "replace"
)
+
// ChannelMonitor 渠道监控配置(service 层模型,不直接暴露 ent 类型)。
type ChannelMonitor struct {
ID int64
@@ -19,6 +32,12 @@ type ChannelMonitor struct {
CreatedAt time.Time
UpdatedAt time.Time
+ // 请求自定义快照(来自模板拷贝 or 用户手填,运行时直接读取)
+ TemplateID *int64 // 仅用于 UI 分组 + 一键应用,运行时不用
+ ExtraHeaders map[string]string // 与 adapter 默认 headers 合并,用户优先
+ BodyOverrideMode string // off / merge / replace
+ BodyOverride map[string]any // 仅 mode != off 时使用
+
// APIKeyDecryptFailed 表示 APIKey 字段无法解密(密钥不一致或损坏)。
// 此时 APIKey 为空字符串,runner / RunCheck 必须跳过该监控并提示重填。
APIKeyDecryptFailed bool
@@ -35,16 +54,20 @@ type ChannelMonitorListParams struct {
// ChannelMonitorCreateParams 创建参数。
type ChannelMonitorCreateParams struct {
- Name string
- Provider string
- Endpoint string
- APIKey string
- PrimaryModel string
- ExtraModels []string
- GroupName string
- Enabled bool
- IntervalSeconds int
- CreatedBy int64
+ Name string
+ Provider string
+ Endpoint string
+ APIKey string
+ PrimaryModel string
+ ExtraModels []string
+ GroupName string
+ Enabled bool
+ IntervalSeconds int
+ CreatedBy int64
+ TemplateID *int64
+ ExtraHeaders map[string]string
+ BodyOverrideMode string
+ BodyOverride map[string]any
}
// ChannelMonitorUpdateParams 更新参数(指针字段表示"未提供则不更新")。
@@ -58,6 +81,14 @@ type ChannelMonitorUpdateParams struct {
GroupName *string
Enabled *bool
IntervalSeconds *int
+ // 自定义快照字段:指针为 nil 表示不更新,非 nil 覆盖
+ // TemplateID *(*int64):用 ** 表达三态:nil=不更新;&nil=清空;&&id=设为 id。
+ // 简化处理:用 ClearTemplate 显式标志 + TemplateID(普通指针)
+ TemplateID *int64
+ ClearTemplate bool // true 时无视 TemplateID,把监控的 template_id 置空
+ ExtraHeaders *map[string]string
+ BodyOverrideMode *string
+ BodyOverride *map[string]any
}
// CheckResult 单个模型一次检测的结果。
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 1482d650..3148f865 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -472,6 +472,7 @@ var ProviderSet = wire.NewSet(
ProvideBalanceNotifyService,
ProvideChannelMonitorService,
ProvideChannelMonitorRunner,
+ NewChannelMonitorRequestTemplateService,
)
// ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named
diff --git a/backend/migrations/128_add_channel_monitor_request_templates.sql b/backend/migrations/128_add_channel_monitor_request_templates.sql
new file mode 100644
index 00000000..2db8fef6
--- /dev/null
+++ b/backend/migrations/128_add_channel_monitor_request_templates.sql
@@ -0,0 +1,70 @@
+-- Migration: 128_add_channel_monitor_request_templates
+-- 加请求模板表 + 给 channel_monitors 加 4 个快照字段(template_id 关联引用 + extra_headers /
+-- body_override_mode / body_override 三个真正运行时使用的快照)。
+--
+-- 设计要点:
+-- 1) 模板与监控之间是「应用即拷贝」的快照语义,运行时 checker 不再回查模板表。
+-- 模板 UPDATE 不会自动影响监控;只有用户主动「应用到关联监控」才会刷新快照。
+-- 2) ON DELETE SET NULL:模板删除不级联清理监控;监控保留快照继续工作。
+-- 3) extra_headers / body_override 都是 JSONB;body_override_mode 用 varchar(不是 enum)
+-- 便于将来加新模式无需 ALTER TYPE。
+-- 4) 同一 provider 内模板 name 唯一(允许 Anthropic + OpenAI 重名 "伪装官方客户端")。
+
-- Templates: a reusable bundle of extra headers + optional body override.
-- body_override_mode is varchar (not an enum type) so new modes need no
-- ALTER TYPE; extra_headers/body_override are JSONB snapshot sources.
CREATE TABLE IF NOT EXISTS channel_monitor_request_templates (
    id BIGSERIAL PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    provider VARCHAR(20) NOT NULL,
    description VARCHAR(500) NOT NULL DEFAULT '',
    extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb,
    body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off',
    body_override JSONB NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    CONSTRAINT channel_monitor_request_templates_provider_check
        CHECK (provider IN ('openai', 'anthropic', 'gemini')),
    CONSTRAINT channel_monitor_request_templates_body_mode_check
        CHECK (body_override_mode IN ('off', 'merge', 'replace'))
);

-- Name is unique within a provider (same name may exist across providers).
CREATE UNIQUE INDEX IF NOT EXISTS channel_monitor_request_templates_provider_name
    ON channel_monitor_request_templates (provider, name);
+
-- Add the 4 snapshot columns to channel_monitors.
-- (ADD COLUMN IF NOT EXISTS needs PG 9.6+; production runs PG 16.)
ALTER TABLE channel_monitors
    ADD COLUMN IF NOT EXISTS template_id BIGINT NULL;
ALTER TABLE channel_monitors
    ADD COLUMN IF NOT EXISTS extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb;
ALTER TABLE channel_monitors
    ADD COLUMN IF NOT EXISTS body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off';
ALTER TABLE channel_monitors
    ADD COLUMN IF NOT EXISTS body_override JSONB NULL;

-- Check constraint + foreign key, guarded by IF NOT EXISTS lookups inside a
-- DO block so the migration stays idempotent (plain ADD CONSTRAINT has no
-- IF NOT EXISTS form).
DO $$
BEGIN
    IF NOT EXISTS (
        SELECT 1 FROM information_schema.table_constraints
        WHERE constraint_name = 'channel_monitors_body_mode_check'
        AND table_name = 'channel_monitors'
    ) THEN
        ALTER TABLE channel_monitors
            ADD CONSTRAINT channel_monitors_body_mode_check
            CHECK (body_override_mode IN ('off', 'merge', 'replace'));
    END IF;

    IF NOT EXISTS (
        SELECT 1 FROM information_schema.table_constraints
        WHERE constraint_name = 'channel_monitors_template_id_fkey'
        AND table_name = 'channel_monitors'
    ) THEN
        -- ON DELETE SET NULL: deleting a template does not cascade to
        -- monitors; they keep their copied snapshot and keep running, only
        -- the back-reference is cleared.
        ALTER TABLE channel_monitors
            ADD CONSTRAINT channel_monitors_template_id_fkey
            FOREIGN KEY (template_id)
            REFERENCES channel_monitor_request_templates (id)
            ON DELETE SET NULL;
    END IF;
END $$;

-- Partial index: only rows that actually reference a template (most won't).
CREATE INDEX IF NOT EXISTS idx_channel_monitors_template_id
    ON channel_monitors (template_id)
    WHERE template_id IS NOT NULL;
diff --git a/frontend/src/api/admin/channelMonitor.ts b/frontend/src/api/admin/channelMonitor.ts
index d9cc6aed..949c4bc8 100644
--- a/frontend/src/api/admin/channelMonitor.ts
+++ b/frontend/src/api/admin/channelMonitor.ts
@@ -7,6 +7,7 @@ import { apiClient } from '../client'
export type Provider = 'openai' | 'anthropic' | 'gemini'
export type MonitorStatus = 'operational' | 'degraded' | 'failed' | 'error'
+export type BodyOverrideMode = 'off' | 'merge' | 'replace'
export interface ChannelMonitor {
id: number
@@ -37,6 +38,11 @@ export interface ChannelMonitor {
availability_7d: number
/** Latest status per extra model (used for hover tooltip) */
extra_models_status: ExtraModelStatus[]
+ /** 请求自定义快照字段(高级设置) */
+ template_id: number | null
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
}
export interface ExtraModelStatus {
@@ -71,10 +77,16 @@ export interface CreateParams {
group_name?: string
enabled?: boolean
interval_seconds: number
+ template_id?: number | null
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
}
-// Update request: api_key empty string means "do not modify"
-export type UpdateParams = Partial
+// Update request: api_key 空串 = 不修改;clear_template=true 时把 template_id 置空
+export type UpdateParams = Partial & {
+ clear_template?: boolean
+}
export interface CheckResult {
model: string
diff --git a/frontend/src/api/admin/channelMonitorTemplate.ts b/frontend/src/api/admin/channelMonitorTemplate.ts
new file mode 100644
index 00000000..258adab8
--- /dev/null
+++ b/frontend/src/api/admin/channelMonitorTemplate.ts
@@ -0,0 +1,108 @@
+/**
+ * Admin Channel Monitor Request Template API.
+ *
+ * 模板 = 一组可复用的 headers + 可选 body 覆盖配置。
+ * 应用到监控 = 拷贝快照;模板后续变动不自动同步,需手动点「应用到关联监控」刷新。
+ */
+
+import { apiClient } from '../client'
+import type { BodyOverrideMode, Provider } from './channelMonitor'
+
+export interface ChannelMonitorTemplate {
+ id: number
+ name: string
+ provider: Provider
+ description: string
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
+ created_at: string
+ updated_at: string
+ /** 关联的监控数量(快照来自此模板,仅 template_id 匹配即可) */
+ associated_monitors: number
+}
+
+export interface ListParams {
+ provider?: Provider
+}
+
+export interface ListResponse {
+ items: ChannelMonitorTemplate[]
+}
+
+export interface CreateParams {
+ name: string
+ provider: Provider
+ description?: string
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+export interface UpdateParams {
+ name?: string
+ description?: string
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+export interface ApplyResponse {
+ affected: number
+}
+
+export async function list(params: ListParams = {}): Promise {
+ const { data } = await apiClient.get('/admin/channel-monitor-templates', {
+ params,
+ })
+ return data
+}
+
+export async function get(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitor-templates/${id}`,
+ )
+ return data
+}
+
+export async function create(params: CreateParams): Promise {
+ const { data } = await apiClient.post(
+ '/admin/channel-monitor-templates',
+ params,
+ )
+ return data
+}
+
+export async function update(id: number, params: UpdateParams): Promise {
+ const { data } = await apiClient.put(
+ `/admin/channel-monitor-templates/${id}`,
+ params,
+ )
+ return data
+}
+
+export async function del(id: number): Promise {
+ await apiClient.delete(`/admin/channel-monitor-templates/${id}`)
+}
+
+/**
+ * Apply the template to all associated monitors (overwrite snapshot fields).
+ * Returns count of affected monitors.
+ */
+export async function apply(id: number): Promise {
+ const { data } = await apiClient.post(
+ `/admin/channel-monitor-templates/${id}/apply`,
+ )
+ return data
+}
+
+export const channelMonitorTemplateAPI = {
+ list,
+ get,
+ create,
+ update,
+ del,
+ apply,
+}
+
+export default channelMonitorTemplateAPI
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index 5e2a9959..9cda5814 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -27,6 +27,7 @@ import backupAPI from './backup'
import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
import channelsAPI from './channels'
import channelMonitorAPI from './channelMonitor'
+import channelMonitorTemplateAPI from './channelMonitorTemplate'
import adminPaymentAPI from './payment'
/**
@@ -57,6 +58,7 @@ export const adminAPI = {
tlsFingerprintProfiles: tlsFingerprintProfileAPI,
channels: channelsAPI,
channelMonitor: channelMonitorAPI,
+ channelMonitorTemplate: channelMonitorTemplateAPI,
payment: adminPaymentAPI
}
@@ -85,6 +87,7 @@ export {
tlsFingerprintProfileAPI,
channelsAPI,
channelMonitorAPI,
+ channelMonitorTemplateAPI,
adminPaymentAPI
}
diff --git a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
new file mode 100644
index 00000000..24827316
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
@@ -0,0 +1,205 @@
+
+
+
+
+
{{ t('admin.channelMonitor.advanced.headers') }}
+
+
{{ headersError }}
+
+ {{ t('admin.channelMonitor.advanced.headersHint') }}
+
+
+
+
+
+
{{ t('admin.channelMonitor.advanced.bodyMode') }}
+
+
+ {{ opt.label }}
+
+
+
+ {{ bodyModeHint }}
+
+
+
+
+
+
{{ t('admin.channelMonitor.advanced.bodyJson') }}
+
+
{{ bodyError }}
+
+ {{ t('admin.channelMonitor.advanced.bodyJsonHint') }}
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorFiltersBar.vue b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
index ebb06a68..eb2a5c78 100644
--- a/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
+++ b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
@@ -44,6 +44,14 @@
>
+
+
+ {{ t('admin.channelMonitor.template.manageButton') }}
+
{{ t('admin.channelMonitor.createButton') }}
@@ -71,6 +79,7 @@ defineProps<{
defineEmits<{
(e: 'reload'): void
(e: 'create'): void
+ (e: 'manage-templates'): void
(e: 'search-input'): void
}>()
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
index 4a538fcf..21fa4715 100644
--- a/frontend/src/components/admin/monitor/MonitorFormDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -95,6 +95,35 @@
{{ t('admin.channelMonitor.form.enabled') }}
+
+
+
+
+ {{ t('admin.channelMonitor.advanced.section') }}
+
+ {{ t('admin.channelMonitor.advanced.sectionHint') }}
+
+
+
+
{{ t('admin.channelMonitor.templateField.label') }}
+
+
{{ t('admin.channelMonitor.templateField.applyHint') }}
+
+
+
+
+
@@ -136,17 +165,21 @@ import { adminAPI } from '@/api/admin'
import { keysAPI } from '@/api/keys'
import { userGroupsAPI } from '@/api/groups'
import type {
+ BodyOverrideMode,
ChannelMonitor,
CreateParams,
Provider,
UpdateParams,
} from '@/api/admin/channelMonitor'
+import type { ChannelMonitorTemplate } from '@/api/admin/channelMonitorTemplate'
import type { ApiKey } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import Toggle from '@/components/common/Toggle.vue'
+import Select from '@/components/common/Select.vue'
import ModelTagInput from '@/components/admin/channel/ModelTagInput.vue'
import { getPlatformTextClass } from '@/components/admin/channel/types'
import MonitorKeyPickerDialog from '@/components/admin/monitor/MonitorKeyPickerDialog.vue'
+import MonitorAdvancedRequestConfig from '@/components/admin/monitor/MonitorAdvancedRequestConfig.vue'
import ProviderIcon from '@/components/user/monitor/ProviderIcon.vue'
import { useChannelMonitorFormat } from '@/composables/useChannelMonitorFormat'
import {
@@ -198,11 +231,16 @@ interface MonitorForm {
group_name: string
interval_seconds: number
enabled: boolean
+ // 高级设置快照
+ template_id: number | null
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
}
const form = reactive({
name: '',
- provider: PROVIDER_OPENAI,
+ provider: PROVIDER_ANTHROPIC,
endpoint: '',
api_key: '',
primary_model: '',
@@ -210,6 +248,57 @@ const form = reactive({
group_name: '',
interval_seconds: systemDefaultInterval.value,
enabled: true,
+ template_id: null,
+ extra_headers: {},
+ body_override_mode: 'off',
+ body_override: null,
+})
+
+// 可用模板列表(进入 dialog 时一次性拉取 cache;按 provider 过滤)。
+const templatesCache = ref([])
+const templatesLoading = ref(false)
+
+const templateOptions = computed(() => {
+ const items = templatesCache.value.filter((t) => t.provider === form.provider)
+ return [
+ { value: '', label: t('admin.channelMonitor.templateField.none') },
+ ...items.map((t) => ({ value: String(t.id), label: t.name })),
+ ]
+})
+
+async function loadTemplates() {
+ if (templatesCache.value.length > 0) return
+ templatesLoading.value = true
+ try {
+ const { items } = await adminAPI.channelMonitorTemplate.list()
+ templatesCache.value = items
+ } catch (err: unknown) {
+ // 模板拉取失败不阻塞监控表单,用户可以不选模板
+ console.warn('load monitor templates failed', err)
+ } finally {
+ templatesLoading.value = false
+ }
+}
+
+// 模板下拉绑定:value 是 string(Select 组件约束),需要与 number | null 互转。
+const templateSelectValue = computed({
+ get: () => (form.template_id == null ? '' : String(form.template_id)),
+ set: (raw: string) => {
+ if (raw === '') {
+ form.template_id = null
+ return
+ }
+ const id = Number(raw)
+ if (!Number.isFinite(id)) return
+ form.template_id = id
+ // 应用模板 = 拷贝快照
+ const tpl = templatesCache.value.find((t) => t.id === id)
+ if (tpl) {
+ form.extra_headers = { ...(tpl.extra_headers || {}) }
+ form.body_override_mode = tpl.body_override_mode
+ form.body_override = tpl.body_override ? { ...tpl.body_override } : null
+ }
+ },
})
interface ProviderOption {
@@ -218,8 +307,8 @@ interface ProviderOption {
}
const providerOptions = computed(() => [
- { value: PROVIDER_OPENAI, label: t('monitorCommon.providers.openai') },
{ value: PROVIDER_ANTHROPIC, label: t('monitorCommon.providers.anthropic') },
+ { value: PROVIDER_OPENAI, label: t('monitorCommon.providers.openai') },
{ value: PROVIDER_GEMINI, label: t('monitorCommon.providers.gemini') },
])
@@ -227,13 +316,15 @@ const providerOptions = computed(() => [
// Editing mode loads api_key='' via loadFromMonitor and only sets it on user
// typing, so clearing on provider change is always a safe no-op until the user
// picks a new key.
+// 同时清空 template_id(模板有 provider 归属,跨平台不通用)。
watch(() => form.provider, () => {
form.api_key = ''
+ form.template_id = null
})
function resetForm() {
form.name = ''
- form.provider = PROVIDER_OPENAI
+ form.provider = PROVIDER_ANTHROPIC
form.endpoint = ''
form.api_key = ''
form.primary_model = ''
@@ -241,6 +332,10 @@ function resetForm() {
form.group_name = ''
form.interval_seconds = systemDefaultInterval.value
form.enabled = true
+ form.template_id = null
+ form.extra_headers = {}
+ form.body_override_mode = 'off'
+ form.body_override = null
}
function loadFromMonitor(m: ChannelMonitor) {
@@ -253,13 +348,19 @@ function loadFromMonitor(m: ChannelMonitor) {
form.group_name = m.group_name || ''
form.interval_seconds = m.interval_seconds || systemDefaultInterval.value
form.enabled = m.enabled
+ form.template_id = m.template_id ?? null
+ form.extra_headers = { ...(m.extra_headers || {}) }
+ form.body_override_mode = m.body_override_mode || 'off'
+ form.body_override = m.body_override ? { ...m.body_override } : null
}
// Re-sync form whenever the dialog is opened or the target monitor changes.
+// 同时拉取模板列表(cache 过的话一次性返回)。
watch(
() => [props.show, props.monitor] as const,
([show, m]) => {
if (!show) return
+ void loadTemplates()
if (m) loadFromMonitor(m)
else resetForm()
},
@@ -310,6 +411,10 @@ function buildPayload(): CreateParams {
group_name: form.group_name.trim(),
enabled: form.enabled,
interval_seconds: form.interval_seconds,
+ template_id: form.template_id,
+ extra_headers: form.extra_headers,
+ body_override_mode: form.body_override_mode,
+ body_override: form.body_override,
}
}
@@ -329,9 +434,14 @@ async function handleSubmit() {
const target = editing.value
if (target) {
const { api_key, ...rest } = buildPayload()
- const req: UpdateParams = rest
+ const req: UpdateParams = { ...rest }
// Only send api_key if user typed a new value
if (api_key) req.api_key = api_key
+ // template_id=null 用 clear_template=true 明确告诉后端清空(pointer 语义)
+ if (form.template_id == null) {
+ req.clear_template = true
+ delete req.template_id
+ }
await adminAPI.channelMonitor.update(target.id, req)
appStore.showSuccess(t('admin.channelMonitor.updateSuccess'))
} else {
diff --git a/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue b/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
new file mode 100644
index 00000000..992a402e
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
@@ -0,0 +1,465 @@
+
+
+
+
+
+
+ {{ tab.label }}
+
+ {{ countByProvider[tab.value] }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.createButton') }}
+
+
+
+
+ {{ t('common.loading') }}
+
+
+
+ {{ t('admin.channelMonitor.template.emptyState') }}
+
+
+
+
+
+
+ {{ tpl.name }}
+
+ {{ modeLabel(tpl.body_override_mode) }}
+
+
+ {{ t('admin.channelMonitor.template.associatedCount', { n: tpl.associated_monitors }) }}
+
+
+
+ {{ tpl.description }}
+
+
+ {{ t('admin.channelMonitor.template.headersSummary', {
+ n: Object.keys(tpl.extra_headers || {}).length,
+ }) }}
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.applyButton') }}
+
+
+ {{ t('common.edit') }}
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.form.name') }}
+ *
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.form.provider') }}
+ *
+
+
+
+ {{ opt.label }}
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.form.description') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('common.back') }}
+
+
+
+
+
+ {{ t('common.close') }}
+
+
+ {{ submitting ? t('common.submitting') : editing === 'new' ? t('common.create') : t('common.update') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorHero.vue b/frontend/src/components/user/monitor/MonitorHero.vue
index 7fc4d846..e978e66c 100644
--- a/frontend/src/components/user/monitor/MonitorHero.vue
+++ b/frontend/src/components/user/monitor/MonitorHero.vue
@@ -1,52 +1,49 @@
-
-
-
-
+
+
+
-
- {{ opt.label }}
-
-
+ {{ opt.label }}
+
+
+
+
+
+ {{ overallLabel }}
+
-
-
-
- {{ overallLabel }}
-
-
-
-
-
+
+
+
-
- {{ updatedLabel }} · {{ t('monitorCommon.pollEvery', { n: intervalSeconds }) }}
-
+
+ {{ updatedLabel }} · {{ t('monitorCommon.pollEvery', { n: intervalSeconds }) }}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index f8d5befa..be99cb7c 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2156,7 +2156,57 @@ export default {
},
runResultTitle: 'Check Result',
noMonitorsYet: 'No monitors yet',
- createFirstMonitor: 'Create your first monitor to track channel availability'
+ createFirstMonitor: 'Create your first monitor to track channel availability',
+ advanced: {
+ section: 'Advanced (optional)',
+ sectionHint: 'Customize request headers and body to bypass upstream client-detection (e.g. "only Claude Code clients allowed").',
+ headers: 'Custom request headers',
+ headersPlaceholder: 'User-Agent: claude-cli/1.0.83 (external, cli)\nx-app: cli\nanthropic-beta: claude-code-20250219',
+ headersHint: 'One Key: Value per line; merged on top of adapter defaults (user wins). Hop-by-hop headers (Host / Content-Length / ...) are ignored.',
+ headersParseError: 'Cannot parse line: {line}',
+ bodyMode: 'Body handling',
+ bodyModeOff: 'Default',
+ bodyModeMerge: 'Merge',
+ bodyModeReplace: 'Replace',
+ bodyModeHintOff: 'Use the adapter default body (includes challenge validation).',
+ bodyModeHintMerge: 'Shallow-merge with the default body; user fields win but model / messages / contents are protected (use Replace to change those).',
+ bodyModeHintReplace: 'Use the JSON below as the complete body. Challenge validation is skipped; HTTP 2xx + non-empty response text is treated as operational.',
+ bodyJson: 'Body JSON',
+ bodyJsonHint: 'Parsed on blur. Empty means no override.',
+ bodyJsonError: 'JSON parse failed',
+ bodyJsonObjectError: 'Body must be a JSON object (no arrays or primitives)'
+ },
+ templateField: {
+ label: 'Request template',
+ none: 'No template',
+ placeholder: 'Pick a template (filtered by current provider)',
+ applyHint: 'Picking a template copies its headers and body to this monitor (snapshot). Later template edits are not auto-synced.'
+ },
+ template: {
+ manageButton: 'Templates',
+ managerTitle: 'Request template manager',
+ createButton: 'New template',
+ emptyState: 'No templates for this provider yet',
+ missingName: 'Template name is required',
+ createSuccess: 'Template created',
+ updateSuccess: 'Template updated',
+ deleteSuccess: 'Template deleted',
+ applyButton: 'Apply to monitors',
+ applyTooltip: 'Overwrite snapshot fields on all associated monitors',
+ applyTitle: 'Apply template',
+ applyConfirm: 'Apply',
+ applyConfirmMessage: 'Overwrite {n} associated monitor(s) with the current configuration of "{name}"? Any local customizations on those monitors will be discarded.',
+ applySuccess: 'Applied to {n} monitor(s)',
+ deleteConfirm: 'Delete template "{name}"? {n} associated monitor(s) will be disassociated but keep their current snapshot and continue running.',
+ associatedCount: '{n} associated monitor(s)',
+ headersSummary: '{n} custom header(s)',
+ form: {
+ name: 'Template name',
+ namePlaceholder: 'e.g. Claude Code mimicry',
+ description: 'Description',
+ descriptionPlaceholder: 'Optional: what this template is for, capture date, etc.'
+ }
+ }
},
// Subscriptions
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index eea34140..a3ce8716 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2235,7 +2235,57 @@ export default {
},
runResultTitle: '检测结果',
noMonitorsYet: '暂无监控',
- createFirstMonitor: '创建第一个监控来跟踪渠道可用性'
+ createFirstMonitor: '创建第一个监控来跟踪渠道可用性',
+ advanced: {
+ section: '高级(可选)',
+ sectionHint: '自定义请求头和请求体,用于突破上游的客户端识别限制(如仅允许 Claude Code 客户端)。',
+ headers: '自定义请求头',
+ headersPlaceholder: 'User-Agent: claude-cli/1.0.83 (external, cli)\nx-app: cli\nanthropic-beta: claude-code-20250219',
+ headersHint: '每行一对 Key: Value;会与默认请求头合并,用户值优先。hop-by-hop 类 header(Host/Content-Length/...)会被忽略。',
+ headersParseError: '无法解析这一行:{line}',
+ bodyMode: '请求体处理',
+ bodyModeOff: '默认',
+ bodyModeMerge: '合并',
+ bodyModeReplace: '覆盖',
+ bodyModeHintOff: '使用 adapter 默认请求体(带 challenge 数学题校验)。',
+ bodyModeHintMerge: '与默认请求体浅合并,用户字段优先;但 model / messages / contents 会被保护不允许覆盖(动这些字段请用「覆盖」模式)。',
+ bodyModeHintReplace: '完全用下方 JSON 作为请求体。注意:此模式下跳过 challenge 校验,改为 HTTP 2xx + 响应文本非空即视为可用。',
+ bodyJson: 'Body JSON',
+ bodyJsonHint: '失焦时自动解析校验。留空等价于没有覆盖。',
+ bodyJsonError: 'JSON 解析失败',
+ bodyJsonObjectError: '请求体必须是一个 JSON 对象(不能是数组或基本类型)'
+ },
+ templateField: {
+ label: '请求模板',
+ none: '不使用模板',
+ placeholder: '选择一个模板(按当前平台过滤)',
+ applyHint: '选中模板后,会把模板的请求头和请求体拷贝到此监控(快照)。后续模板变动不自动同步。'
+ },
+ template: {
+ manageButton: '模板管理',
+ managerTitle: '请求模板管理',
+ createButton: '新建模板',
+ emptyState: '当前平台下还没有请求模板',
+ missingName: '请输入模板名称',
+ createSuccess: '模板创建成功',
+ updateSuccess: '模板更新成功',
+ deleteSuccess: '模板删除成功',
+ applyButton: '应用到关联监控',
+ applyTooltip: '把当前模板配置覆盖到所有关联的监控上',
+ applyTitle: '应用模板',
+ applyConfirm: '确认应用',
+ applyConfirmMessage: '将把模板「{name}」的当前配置覆盖到 {n} 个关联监控。监控本地已编辑的自定义修改会被丢弃,是否继续?',
+ applySuccess: '已应用到 {n} 个监控',
+ deleteConfirm: '确定要删除模板「{name}」吗?{n} 个关联监控会解除关联但保留自己的快照继续工作。',
+ associatedCount: '{n} 个关联监控',
+ headersSummary: '{n} 个自定义请求头',
+ form: {
+ name: '模板名称',
+ namePlaceholder: '例:Claude Code 伪装',
+ description: '说明',
+ descriptionPlaceholder: '可选:说明这个模板的用途和来源(抓包日期等)'
+ }
+ }
},
// Subscriptions Management
diff --git a/frontend/src/views/admin/ChannelMonitorView.vue b/frontend/src/views/admin/ChannelMonitorView.vue
index 8f0a1e2f..fab19f26 100644
--- a/frontend/src/views/admin/ChannelMonitorView.vue
+++ b/frontend/src/views/admin/ChannelMonitorView.vue
@@ -9,6 +9,7 @@
:loading="loading"
@reload="reload"
@create="openCreateDialog"
+ @manage-templates="showTemplateManager = true"
@search-input="handleSearch"
/>
@@ -86,6 +87,12 @@
@saved="reload"
/>
+
+
('')
const pagination = reactive({ page: 1, page_size: getPersistedPageSize(), total: 0 })
const showDialog = ref(false)
+const showTemplateManager = ref(false)
const editing = ref(null)
const showDeleteDialog = ref(false)
const deleting = ref(null)
--
GitLab
From 6925ac25c4d8a9e9af77403c54529a5f06e78c91 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 14:39:19 +0800
Subject: [PATCH 093/261] feat(channel-monitor): apply template via subset
picker; CC 2.1.114 baseline doc
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Apply flow:
- POST /admin/channel-monitor-templates/:id/apply now requires monitor_ids
(non-empty array). Service applies the template only to the selected
subset, gated by AND template_id = :id (so users can't sneak in
unrelated monitor IDs).
- New GET /admin/channel-monitor-templates/:id/monitors returns the
associated monitor briefs (id/name/provider/enabled) for the picker.
- ApplyToMonitors signature gains monitorIDs []int64; empty list returns
ErrChannelMonitorTemplateApplyEmpty.
Frontend:
- New MonitorTemplateApplyPickerDialog.vue: list of associated monitors
with checkboxes (default all checked), 全选 / 全不选 shortcuts, live
selected/total count. Submit calls apply(id, ids).
- MonitorTemplateManagerDialog replaces the old ConfirmDialog flow with
the picker; onApplied refetches the list to refresh associated counts.
i18n: applyPicker* + common.selectAll keys.
chore: bump version to 0.1.114.33
The CC 2.1.114 (sdk-cli) UA / APIKeyBetaHeader / JSON metadata.user_id
baseline (already verified working via the in-process apply on prod
template id=1) is documented in internal/pkg/claude/constants.go and
is what the seed template in the manager UI should follow.
---
.../admin/channel_monitor_template_handler.go | 43 ++++-
.../channel_monitor_template_repo.go | 39 +++-
backend/internal/server/routes/admin.go | 1 +
.../channel_monitor_template_service.go | 38 +++-
.../service/channel_monitor_template_types.go | 3 +
.../src/api/admin/channelMonitorTemplate.ts | 30 ++-
.../MonitorTemplateApplyPickerDialog.vue | 174 ++++++++++++++++++
.../monitor/MonitorTemplateManagerDialog.vue | 48 ++---
frontend/src/i18n/locales/en.ts | 9 +-
frontend/src/i18n/locales/zh.ts | 7 +
10 files changed, 341 insertions(+), 51 deletions(-)
create mode 100644 frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue
diff --git a/backend/internal/handler/admin/channel_monitor_template_handler.go b/backend/internal/handler/admin/channel_monitor_template_handler.go
index 8c1191ea..bebe0929 100644
--- a/backend/internal/handler/admin/channel_monitor_template_handler.go
+++ b/backend/internal/handler/admin/channel_monitor_template_handler.go
@@ -179,17 +179,56 @@ func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) {
response.Success(c, nil)
}
+type channelMonitorTemplateApplyRequest struct {
+ // MonitorIDs 必填、非空:用户在 picker 里勾选的要被覆盖的监控 ID 列表。
+ // 仅当对应监控当前 template_id == :id 时才会真的被覆盖。
+ MonitorIDs []int64 `json:"monitor_ids" binding:"required,min=1"`
+}
+
// Apply POST /api/v1/admin/channel-monitor-templates/:id/apply
-// 一键把模板当前配置覆盖到所有关联监控上。
+// 把模板当前配置覆盖到 monitor_ids 列表里的关联监控(picker 选中的子集)。
func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
- affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id)
+ var req channelMonitorTemplateApplyRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id, req.MonitorIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"affected": affected})
}
+
+type associatedMonitorBriefResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Enabled bool `json:"enabled"`
+}
+
+// AssociatedMonitors GET /api/v1/admin/channel-monitor-templates/:id/monitors
+// 列出关联监控(picker 弹窗用)。
+func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ items, err := h.templateService.ListAssociatedMonitors(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]associatedMonitorBriefResponse, 0, len(items))
+ for _, m := range items {
+ out = append(out, associatedMonitorBriefResponse{
+ ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
+ })
+ }
+ response.Success(c, gin.H{"items": out})
+}
diff --git a/backend/internal/repository/channel_monitor_template_repo.go b/backend/internal/repository/channel_monitor_template_repo.go
index 03f3692b..845d186b 100644
--- a/backend/internal/repository/channel_monitor_template_repo.go
+++ b/backend/internal/repository/channel_monitor_template_repo.go
@@ -103,11 +103,13 @@ func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, para
return out, nil
}
-// ApplyToMonitors 把模板当前配置批量覆盖到 template_id = id 的监控上。
-//
-// 用一条 UPDATE 完成:extra_headers / body_override_mode / body_override 都覆盖。
-// 走 ent 的 UpdateMany 保证走 ent hooks;走原生 SQL 也可以但 ent jsonb 序列化更省心。
-func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64) (int64, error) {
+// ApplyToMonitors 把模板当前配置覆盖到 monitorIDs 列表里的关联监控。
+// WHERE 双重过滤:template_id = id AND id IN (monitorIDs),防止用户传了未关联本模板的 id
+// 就被覆盖。走 ent UpdateMany 保留 hooks。
+func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) {
+ if len(monitorIDs) == 0 {
+ return 0, nil
+ }
client := clientFromContext(ctx, r.client)
tpl, err := client.ChannelMonitorRequestTemplate.Query().
Where(channelmonitorrequesttemplate.IDEQ(id)).
@@ -117,7 +119,10 @@ func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Co
}
updater := client.ChannelMonitor.Update().
- Where(channelmonitor.TemplateIDEQ(id)).
+ Where(
+ channelmonitor.TemplateIDEQ(id),
+ channelmonitor.IDIn(monitorIDs...),
+ ).
SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)).
SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode))
if tpl.BodyOverride != nil {
@@ -144,6 +149,28 @@ func (r *channelMonitorRequestTemplateRepository) CountAssociatedMonitors(ctx co
return int64(count), nil
}
+// ListAssociatedMonitors 列出模板关联的所有监控简略字段。
+// ORDER BY name 稳定输出方便前端展示。
+func (r *channelMonitorRequestTemplateRepository) ListAssociatedMonitors(ctx context.Context, id int64) ([]*service.AssociatedMonitorBrief, error) {
+ rows, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.TemplateIDEQ(id)).
+ Order(dbent.Asc(channelmonitor.FieldName)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list associated monitors for template %d: %w", id, err)
+ }
+ out := make([]*service.AssociatedMonitorBrief, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, &service.AssociatedMonitorBrief{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Enabled: row.Enabled,
+ })
+ }
+ return out, nil
+}
+
// ---------- helpers ----------
func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.ChannelMonitorRequestTemplate {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 13cecd59..4b796d55 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -587,6 +587,7 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
templates.GET("/:id", h.Admin.ChannelMonitorTemplate.Get)
templates.PUT("/:id", h.Admin.ChannelMonitorTemplate.Update)
templates.DELETE("/:id", h.Admin.ChannelMonitorTemplate.Delete)
+ templates.GET("/:id/monitors", h.Admin.ChannelMonitorTemplate.AssociatedMonitors)
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
}
}
diff --git a/backend/internal/service/channel_monitor_template_service.go b/backend/internal/service/channel_monitor_template_service.go
index 98fc930b..8d2e8173 100644
--- a/backend/internal/service/channel_monitor_template_service.go
+++ b/backend/internal/service/channel_monitor_template_service.go
@@ -15,10 +15,23 @@ type ChannelMonitorRequestTemplateRepository interface {
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error)
// ApplyToMonitors 把模板当前的 extra_headers / body_override_mode / body_override
- // 批量覆盖到所有 template_id = id 的监控上。返回被覆盖的监控数量。
- ApplyToMonitors(ctx context.Context, id int64) (int64, error)
+ // 批量覆盖到指定 monitorIDs 的监控上(同时还要求这些监控当前 template_id = id,
+ // 防止误覆盖未关联的监控)。monitorIDs 必须非空;空列表直接返回 0 不写库。
+ // 返回被覆盖的监控数量。
+ ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error)
// CountAssociatedMonitors 统计 template_id = id 的监控数(用于 UI 展示「应用到 N 个配置」)。
CountAssociatedMonitors(ctx context.Context, id int64) (int64, error)
+ // ListAssociatedMonitors 列出所有 template_id = id 的监控简略信息(id/name/provider/enabled)
+ // 给 apply picker UI 用,避免前端再做一次 list+filter。
+ ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error)
+}
+
+// AssociatedMonitorBrief 模板关联监控的简略信息(picker / 列表展示用)。
+type AssociatedMonitorBrief struct {
+ ID int64
+ Name string
+ Provider string
+ Enabled bool
}
// ChannelMonitorRequestTemplateService 模板管理 service。
@@ -90,13 +103,17 @@ func (s *ChannelMonitorRequestTemplateService) Delete(ctx context.Context, id in
return nil
}
-// ApplyToMonitors 把模板当前配置一键应用到所有关联监控。
-// 返回被影响的监控数。
-func (s *ChannelMonitorRequestTemplateService) ApplyToMonitors(ctx context.Context, id int64) (int64, error) {
+// ApplyToMonitors 把模板当前配置应用到 monitorIDs 列表里的关联监控。
+// monitorIDs 必须非空且每个 id 都必须当前 template_id = id;不满足条件的会被 SQL WHERE 过滤掉。
+// 返回实际被覆盖的监控数。
+func (s *ChannelMonitorRequestTemplateService) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) {
if _, err := s.repo.GetByID(ctx, id); err != nil {
return 0, err
}
- affected, err := s.repo.ApplyToMonitors(ctx, id)
+ if len(monitorIDs) == 0 {
+ return 0, ErrChannelMonitorTemplateApplyEmpty
+ }
+ affected, err := s.repo.ApplyToMonitors(ctx, id, monitorIDs)
if err != nil {
return 0, fmt.Errorf("apply template to monitors: %w", err)
}
@@ -108,6 +125,15 @@ func (s *ChannelMonitorRequestTemplateService) CountAssociatedMonitors(ctx conte
return s.repo.CountAssociatedMonitors(ctx, id)
}
+// ListAssociatedMonitors 返回模板关联的所有监控简略信息。
+// 给前端 apply picker 用,handler 直接吐 JSON 不再做 join。
+func (s *ChannelMonitorRequestTemplateService) ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error) {
+ if _, err := s.repo.GetByID(ctx, id); err != nil {
+ return nil, err
+ }
+ return s.repo.ListAssociatedMonitors(ctx, id)
+}
+
// ---------- 校验 & 工具 ----------
// validateTemplateCreateParams 聚合 create 入参校验,避免函数超过 30 行。
diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go
index a6e2bb59..e5bf7568 100644
--- a/backend/internal/service/channel_monitor_template_types.go
+++ b/backend/internal/service/channel_monitor_template_types.go
@@ -71,4 +71,7 @@ var (
ErrChannelMonitorTemplateProviderMismatch = infraerrors.BadRequest(
"CHANNEL_MONITOR_TEMPLATE_PROVIDER_MISMATCH", "monitor provider does not match template provider",
)
+ ErrChannelMonitorTemplateApplyEmpty = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_APPLY_EMPTY", "monitor_ids must be a non-empty array",
+ )
)
diff --git a/frontend/src/api/admin/channelMonitorTemplate.ts b/frontend/src/api/admin/channelMonitorTemplate.ts
index 258adab8..01b3c2d0 100644
--- a/frontend/src/api/admin/channelMonitorTemplate.ts
+++ b/frontend/src/api/admin/channelMonitorTemplate.ts
@@ -51,6 +51,17 @@ export interface ApplyResponse {
affected: number
}
+export interface AssociatedMonitorBrief {
+ id: number
+ name: string
+ provider: Provider
+ enabled: boolean
+}
+
+export interface AssociatedMonitorsResponse {
+ items: AssociatedMonitorBrief[]
+}
+
export async function list(params: ListParams = {}): Promise {
const { data } = await apiClient.get('/admin/channel-monitor-templates', {
params,
@@ -86,12 +97,24 @@ export async function del(id: number): Promise {
}
/**
- * Apply the template to all associated monitors (overwrite snapshot fields).
- * Returns count of affected monitors.
+ * Apply the template to the specified associated monitors (overwrite snapshot fields).
+ * monitorIds must be a non-empty subset of the template's associated monitors.
+ * Returns count of actually affected monitors.
*/
-export async function apply(id: number): Promise {
+export async function apply(id: number, monitorIds: number[]): Promise {
const { data } = await apiClient.post(
`/admin/channel-monitor-templates/${id}/apply`,
+ { monitor_ids: monitorIds },
+ )
+ return data
+}
+
+/**
+ * List monitors currently associated to this template (used by apply picker).
+ */
+export async function listAssociatedMonitors(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitor-templates/${id}/monitors`,
)
return data
}
@@ -103,6 +126,7 @@ export const channelMonitorTemplateAPI = {
update,
del,
apply,
+ listAssociatedMonitors,
}
export default channelMonitorTemplateAPI
diff --git a/frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue
new file mode 100644
index 00000000..427b75ff
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue
@@ -0,0 +1,174 @@
+
+
+
+ {{ t('admin.channelMonitor.template.applyPickerHint') }}
+
+
+
+ {{ t('common.loading') }}
+
+
+
+ {{ t('admin.channelMonitor.template.applyPickerEmpty') }}
+
+
+
+
+
+
+ {{ t('common.selectAll') }}
+
+
+ {{ t('admin.channelMonitor.template.selectNone') }}
+
+
+ {{ t('admin.channelMonitor.template.selectedCount', {
+ n: selectedIds.length,
+ total: monitors.length,
+ }) }}
+
+
+
+
+
+
+
+
+
+ {{ t('common.cancel') }}
+
+
+ {{ submitting
+ ? t('common.submitting')
+ : t('admin.channelMonitor.template.applyPickerConfirm', { n: selectedIds.length }) }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue b/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
index 992a402e..3a03f5bc 100644
--- a/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
+++ b/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
@@ -180,14 +180,12 @@
-
({
+// --- apply to monitors (picker 流程) ---
+const applyPicker = reactive<{ show: boolean; tpl: ChannelMonitorTemplate | null }>({
show: false,
tpl: null,
})
function confirmApply(tpl: ChannelMonitorTemplate) {
- confirmApply_.tpl = tpl
- confirmApply_.show = true
+ applyPicker.tpl = tpl
+ applyPicker.show = true
}
-const confirmApplyMessage = computed(() => {
- const tpl = confirmApply_.tpl
- if (!tpl) return ''
- return t('admin.channelMonitor.template.applyConfirmMessage', {
- name: tpl.name,
- n: tpl.associated_monitors,
- })
-})
-
-async function doApply() {
- const tpl = confirmApply_.tpl
- confirmApply_.show = false
- if (!tpl) return
- try {
- const { affected } = await adminAPI.channelMonitorTemplate.apply(tpl.id)
- appStore.showSuccess(t('admin.channelMonitor.template.applySuccess', { n: affected }))
- await fetchTemplates()
- emit('updated')
- } catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
- }
+// picker 提交后触发:刷新模板列表(拿最新 associated_monitors)+ 通知父组件
+async function onApplied(_affected: number) {
+ await fetchTemplates()
+ emit('updated')
}
// --- delete ---
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index be99cb7c..1d49efbf 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -273,6 +273,7 @@ export default {
no: 'No',
all: 'All',
none: 'None',
+ selectAll: 'Select all',
noData: 'No data',
expand: 'Expand',
collapse: 'Collapse',
@@ -2192,11 +2193,17 @@ export default {
updateSuccess: 'Template updated',
deleteSuccess: 'Template deleted',
applyButton: 'Apply to monitors',
- applyTooltip: 'Overwrite snapshot fields on all associated monitors',
+ applyTooltip: 'Overwrite snapshot fields on associated monitors',
applyTitle: 'Apply template',
applyConfirm: 'Apply',
applyConfirmMessage: 'Overwrite {n} associated monitor(s) with the current configuration of "{name}"? Any local customizations on those monitors will be discarded.',
applySuccess: 'Applied to {n} monitor(s)',
+ applyPickerTitle: 'Apply template "{name}"',
+ applyPickerHint: 'Select which monitors to overwrite (all selected by default). Any local customizations will be discarded.',
+ applyPickerEmpty: 'No monitors are currently associated to this template',
+ applyPickerConfirm: 'Apply to {n} monitor(s)',
+ selectNone: 'Select none',
+ selectedCount: 'Selected {n} / {total}',
deleteConfirm: 'Delete template "{name}"? {n} associated monitor(s) will be disassociated but keep their current snapshot and continue running.',
associatedCount: '{n} associated monitor(s)',
headersSummary: '{n} custom header(s)',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index a3ce8716..fb84dfd2 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -273,6 +273,7 @@ export default {
no: '否',
all: '全部',
none: '无',
+ selectAll: '全选',
noData: '暂无数据',
expand: '展开',
collapse: '收起',
@@ -2276,6 +2277,12 @@ export default {
applyConfirm: '确认应用',
applyConfirmMessage: '将把模板「{name}」的当前配置覆盖到 {n} 个关联监控。监控本地已编辑的自定义修改会被丢弃,是否继续?',
applySuccess: '已应用到 {n} 个监控',
+ applyPickerTitle: '应用模板「{name}」',
+ applyPickerHint: '勾选要覆盖请求头/请求体的监控(默认全选)。监控本地已编辑的自定义修改会被丢弃。',
+ applyPickerEmpty: '当前模板没有关联监控',
+ applyPickerConfirm: '应用到 {n} 个监控',
+ selectNone: '全不选',
+ selectedCount: '已选 {n} / {total}',
deleteConfirm: '确定要删除模板「{name}」吗?{n} 个关联监控会解除关联但保留自己的快照继续工作。',
associatedCount: '{n} 个关联监控',
headersSummary: '{n} 个自定义请求头',
--
GitLab
From a7415d4d2ef319429f8e8bd04ca874b73b031111 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 15:24:48 +0800
Subject: [PATCH 094/261] feat(monitor): 30-day raw retention + timeline 4-tier
style + CC template seed + JSON format button
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- History retention 1d → 30d(60s × 30d ≈ 43200 行/model,PG 无压力);
ComputeAvailability* 不再 UNION rollup 表,直接扫 histories 精度更高。
- Timeline bar 四级高度+颜色双重编码:operational 高+绿 / degraded 中+黄 /
failed+error 短+红 / 未测试 很短+灰。
- migration 129 seed「Claude Code 伪装」模板(ON CONFLICT DO NOTHING)。
user_id 用 legacy 格式(user_<64hex>_account__session_),
避免新版 JSON 字符串内嵌 JSON 在编辑器里一长串 \" 难读。
- MonitorAdvancedRequestConfig 加「格式化」按钮 + white-space:pre
让 body textarea 对长字符串不压扁。
---
.../repository/channel_monitor_repo.go | 74 +++++--------------
.../internal/service/channel_monitor_const.go | 7 +-
.../129_seed_claude_code_template.sql | 38 ++++++++++
.../monitor/MonitorAdvancedRequestConfig.vue | 35 ++++++++-
.../user/monitor/MonitorTimeline.vue | 10 ++-
frontend/src/i18n/locales/en.ts | 1 +
frontend/src/i18n/locales/zh.ts | 1 +
7 files changed, 102 insertions(+), 64 deletions(-)
create mode 100644 backend/migrations/129_seed_claude_code_template.sql
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
index 67dccd6c..800ee43b 100644
--- a/backend/internal/repository/channel_monitor_repo.go
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -297,41 +297,22 @@ func assignNullInt(dst **int, n sql.NullInt64) {
// "可用" = status IN (operational, degraded)。
//
// 数据来源:明细表只保留 1 天;窗口前其余天数走聚合表。
-// - raw = 今天(CURRENT_DATE 起)的未软删明细,按 model 累加
-// - rollup = [CURRENT_DATE - windowDays, CURRENT_DATE) 区间的聚合行
-//
-// 总窗口为 "今天 + 过去 windowDays 天",比 windowDays 字面值大 1 天,但因为聚合
-// 是按整 UTC 日切的,这是聚合化无法避免的精度损失,且偏宽不偏窄(数据更全)。
+// 明细保留 30 天(monitorHistoryRetentionDays),窗口 <= 30 天时直接扫 histories,
+// 精度到秒,避免与聚合表 UNION 带来的 UTC 日切精度损失。
func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) {
if windowDays <= 0 {
windowDays = 7
}
const q = `
- WITH raw AS (
- SELECT model,
- COUNT(*) AS total_checks,
- COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
- COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
- COUNT(latency_ms) AS count_latency
- FROM channel_monitor_histories
- WHERE monitor_id = $1
- AND checked_at >= CURRENT_DATE
- GROUP BY model
- ),
- rollup AS (
- SELECT model, total_checks, ok_count, sum_latency_ms, count_latency
- FROM channel_monitor_daily_rollups
- WHERE monitor_id = $1
- AND bucket_date >= (CURRENT_DATE - $2::int)
- AND bucket_date < CURRENT_DATE
- )
SELECT model,
- SUM(total_checks) AS total,
- SUM(ok_count) AS ok,
- CASE WHEN SUM(count_latency) > 0
- THEN SUM(sum_latency_ms)::float8 / SUM(count_latency)
- ELSE NULL END AS avg_latency_ms
- FROM (SELECT * FROM raw UNION ALL SELECT * FROM rollup) combined
+ COUNT(*) AS total,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
+ CASE WHEN COUNT(latency_ms) > 0
+ THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
+ ELSE NULL END AS avg_latency_ms
+ FROM channel_monitor_histories
+ WHERE monitor_id = $1
+ AND checked_at >= NOW() - ($2::int || ' days')::interval
GROUP BY model
`
rows, err := r.db.QueryContext(ctx, q, monitorID, windowDays)
@@ -514,7 +495,7 @@ func clampTimelineLimit(n int) int {
}
// ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。
-// 与单 monitor 版本同构:明细只覆盖今天,更早走聚合表 UNION 合并。
+// 明细保留 30 天,直接扫 histories(窗口 <= 30 天时无需聚合)。
func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) {
out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids))
if len(ids) == 0 {
@@ -524,33 +505,16 @@ func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Co
windowDays = 7
}
const q = `
- WITH raw AS (
- SELECT monitor_id,
- model,
- COUNT(*) AS total_checks,
- COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
- COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
- COUNT(latency_ms) AS count_latency
- FROM channel_monitor_histories
- WHERE monitor_id = ANY($1)
- AND checked_at >= CURRENT_DATE
- GROUP BY monitor_id, model
- ),
- rollup AS (
- SELECT monitor_id, model, total_checks, ok_count, sum_latency_ms, count_latency
- FROM channel_monitor_daily_rollups
- WHERE monitor_id = ANY($1)
- AND bucket_date >= (CURRENT_DATE - $2::int)
- AND bucket_date < CURRENT_DATE
- )
SELECT monitor_id,
model,
- SUM(total_checks) AS total,
- SUM(ok_count) AS ok,
- CASE WHEN SUM(count_latency) > 0
- THEN SUM(sum_latency_ms)::float8 / SUM(count_latency)
- ELSE NULL END AS avg_latency_ms
- FROM (SELECT * FROM raw UNION ALL SELECT * FROM rollup) combined
+ COUNT(*) AS total,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
+ CASE WHEN COUNT(latency_ms) > 0
+ THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
+ ELSE NULL END AS avg_latency_ms
+ FROM channel_monitor_histories
+ WHERE monitor_id = ANY($1)
+ AND checked_at >= NOW() - ($2::int || ' days')::interval
GROUP BY monitor_id, model
`
rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), windowDays)
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
index 768a432f..2fc45639 100644
--- a/backend/internal/service/channel_monitor_const.go
+++ b/backend/internal/service/channel_monitor_const.go
@@ -16,9 +16,10 @@ const (
// monitorDegradedThreshold 主请求成功但耗时超过该阈值视为 degraded。
monitorDegradedThreshold = 6 * time.Second
// monitorHistoryRetentionDays 明细历史保留天数。
- // 明细只保留 1 天,超出由 SoftDeleteMixin 软删;
- // 维护任务每天凌晨跑(由 OpsCleanupService 统一调度)。
- monitorHistoryRetentionDays = 1
+ // 60s 默认间隔 * 30 天 ≈ 43200 行/monitor/model,一般部署总量 <= 2M 行,
+ // PG 无压力;所以直接保留完整明细一个月,可用率查询可以全走原始行不依赖聚合。
+ // 聚合表 channel_monitor_daily_rollups 仍然保留,作为长期历史回填/降级查询的兜底。
+ monitorHistoryRetentionDays = 30
// monitorRollupRetentionDays 日聚合保留天数。
// 日聚合行由 RunDailyMaintenance 在超过该窗口后软删。
monitorRollupRetentionDays = 30
diff --git a/backend/migrations/129_seed_claude_code_template.sql b/backend/migrations/129_seed_claude_code_template.sql
new file mode 100644
index 00000000..d9b062c9
--- /dev/null
+++ b/backend/migrations/129_seed_claude_code_template.sql
@@ -0,0 +1,38 @@
+-- Migration: 129_seed_claude_code_template
+-- 内置「Claude Code 伪装」请求模板,覆盖 Anthropic 上游对官方 CLI 客户端的所有验证项:
+-- 1) User-Agent / X-App / anthropic-beta / anthropic-version 等头
+-- 2) system 数组首项与官方 system prompt 字面一致(Dice >= 0.5)
+-- 3) metadata.user_id 满足 ParseMetadataUserID — 这里用 legacy 格式(user_<64hex>_account__session_<36char>)
+-- 避免新版 JSON 字符串内嵌 JSON 在编辑器里出现一长串 \" 转义,便于用户阅读。
+--
+-- ON CONFLICT DO NOTHING:已部署环境(手动建过模板)跑此 migration 不会重复 / 覆盖。
+-- 用户可自行编辑后续覆盖此 seed;CC 升大版时再起一条 migration 提供新模板,不动用户的旧模板。
+
+INSERT INTO channel_monitor_request_templates (
+ name, provider, description, extra_headers, body_override_mode, body_override
+)
+VALUES (
+ 'Claude Code 伪装',
+ 'anthropic',
+ '完整模拟 Claude Code 2.1.114 客户端:UA + anthropic-beta + system + metadata.user_id 全部对齐,绕过 Anthropic 上游 ''Claude Code only'' 限制(如 Max 套餐)。',
+ '{
+ "User-Agent": "claude-cli/2.1.114 (external, sdk-cli)",
+ "X-App": "cli",
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "claude-code-20250219,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,advisor-tool-2026-03-01",
+ "anthropic-dangerous-direct-browser-access": "true"
+ }'::jsonb,
+ 'merge',
+ '{
+ "system": [
+ {
+ "type": "text",
+ "text": "You are Claude Code, Anthropic''s official CLI for Claude."
+ }
+ ],
+ "metadata": {
+ "user_id": "user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"
+ }
+ }'::jsonb
+)
+ON CONFLICT (provider, name) DO NOTHING;
diff --git a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
index 24827316..fb503a49 100644
--- a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
+++ b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
@@ -38,12 +38,24 @@
-
{{ t('admin.channelMonitor.advanced.bodyJson') }}
+
+ {{ t('admin.channelMonitor.advanced.bodyJson') }}
+
+ {{ t('admin.channelMonitor.advanced.bodyJsonFormat') }}
+
+
{{ bodyError }}
@@ -158,6 +170,25 @@ function commitBody() {
}
}
+function formatBody() {
+ const trimmed = bodyText.value.trim()
+ if (trimmed === '') return
+ try {
+ const parsed = JSON.parse(trimmed)
+ bodyText.value = JSON.stringify(parsed, null, 2)
+ bodyError.value = ''
+ // 同步把校验过的对象提交,避免格式化后焦点未移走时父组件读到旧值
+ if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
+ emit('update:bodyOverride', parsed as Record
)
+ }
+ } catch (e) {
+ bodyError.value =
+ t('admin.channelMonitor.advanced.bodyJsonError') +
+ ': ' +
+ (e instanceof Error ? e.message : String(e))
+ }
+}
+
function serializeBody(body: Record | null): string {
if (!body || Object.keys(body).length === 0) return ''
return JSON.stringify(body, null, 2)
diff --git a/frontend/src/components/user/monitor/MonitorTimeline.vue b/frontend/src/components/user/monitor/MonitorTimeline.vue
index b4d0c151..2445bc51 100644
--- a/frontend/src/components/user/monitor/MonitorTimeline.vue
+++ b/frontend/src/components/user/monitor/MonitorTimeline.vue
@@ -59,19 +59,21 @@ interface Bar {
title: string
}
+// 4 级高度 + 颜色双重编码:高=好+绿,短=坏+红,灰=未测试。
+// 长绿(正常) > 中黄(降级) > 短红(失败/系统错误) > 很短灰(未测试)。
const STATUS_HEIGHT: Record = {
operational: 100,
- degraded: 70,
- failed: 55,
+ degraded: 65,
+ failed: 35,
error: 35,
- empty: 20,
+ empty: 15,
}
const STATUS_COLOR: Record = {
operational: 'bg-emerald-500',
degraded: 'bg-amber-500',
failed: 'bg-red-500',
- error: 'bg-gray-400 dark:bg-dark-500',
+ error: 'bg-red-500',
empty: 'bg-gray-300 dark:bg-dark-600',
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 1d49efbf..d26462b3 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2173,6 +2173,7 @@ export default {
bodyModeHintMerge: 'Shallow-merge with the default body; user fields win but model / messages / contents are protected (use Replace to change those).',
bodyModeHintReplace: 'Use the JSON below as the complete body. Challenge validation is skipped; HTTP 2xx + non-empty response text is treated as operational.',
bodyJson: 'Body JSON',
+ bodyJsonFormat: 'Format',
bodyJsonHint: 'Parsed on blur. Empty means no override.',
bodyJsonError: 'JSON parse failed',
bodyJsonObjectError: 'Body must be a JSON object (no arrays or primitives)'
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index fb84dfd2..9ac1c20f 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2252,6 +2252,7 @@ export default {
bodyModeHintMerge: '与默认请求体浅合并,用户字段优先;但 model / messages / contents 会被保护不允许覆盖(动这些字段请用「覆盖」模式)。',
bodyModeHintReplace: '完全用下方 JSON 作为请求体。注意:此模式下跳过 challenge 校验,改为 HTTP 2xx + 响应文本非空即视为可用。',
bodyJson: 'Body JSON',
+ bodyJsonFormat: '格式化',
bodyJsonHint: '失焦时自动解析校验。留空等价于没有覆盖。',
bodyJsonError: 'JSON 解析失败',
bodyJsonObjectError: '请求体必须是一个 JSON 对象(不能是数组或基本类型)'
--
GitLab
From e1193212b50c292468a284bb397b7f6fa0ded466 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 15:37:57 +0800
Subject: [PATCH 095/261] feat(monitor): switch headers input to key-value rows
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- AdvancedRequestConfig 把 headers textarea 换成行式:每行 name 输入 + value 输入
+ 删除按钮,底部「+ 添加 Header」。直观区分名/值,不用再一行 "Key: Value" 自己拆。
- 校验下放到行级:name 含空格或冒号才报错,未填仅占位不报错(避免输入时频繁红字)。
- 外部 props 同值不回写,避免 commit 后行被重排。
- chore: 移除 CLAUDE.md 里 silentflower remote 行(不再追踪)。
---
.../monitor/MonitorAdvancedRequestConfig.vue | 137 +++++++++++++-----
frontend/src/i18n/locales/en.ts | 6 +-
frontend/src/i18n/locales/zh.ts | 6 +-
3 files changed, 111 insertions(+), 38 deletions(-)
diff --git a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
index fb503a49..0d6b4ace 100644
--- a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
+++ b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
@@ -1,15 +1,52 @@
-
+
{{ t('admin.channelMonitor.advanced.headers') }}
-
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.advanced.headerAddRow') }}
+
+
{{ headersError }}
{{ t('admin.channelMonitor.advanced.headersHint') }}
@@ -85,51 +122,79 @@ const emit = defineEmits<{
const { t } = useI18n()
-// ---- Headers textarea (Key: Value per line) ----
-const headersText = ref(serializeHeaders(props.extraHeaders))
+// ---- Headers key-value rows ----
+interface HeaderRow {
+ name: string
+ value: string
+}
+
+const headerRows = ref(toRows(props.extraHeaders))
const headersError = ref('')
watch(
() => props.extraHeaders,
(v) => {
- // 外部重置时(如切换平台 / 应用模板)同步文本
- headersText.value = serializeHeaders(v)
+ // 外部重置时(切换平台 / 应用模板)同步行。
+ // 同值不回写,避免每次 commit 都把行重排。
+ if (!isSameHeaderMap(toMap(headerRows.value), v)) {
+ headerRows.value = toRows(v)
+ }
headersError.value = ''
},
)
+function toRows(h: Record): HeaderRow[] {
+ const entries = Object.entries(h || {})
+ if (entries.length === 0) return [{ name: '', value: '' }]
+ return entries.map(([name, value]) => ({ name, value }))
+}
+
+function toMap(rows: HeaderRow[]): Record {
+ const out: Record = {}
+ for (const row of rows) {
+ const name = row.name.trim()
+ if (name === '') continue
+ out[name] = row.value
+ }
+ return out
+}
+
+function isSameHeaderMap(a: Record, b: Record): boolean {
+ const ak = Object.keys(a)
+ const bk = Object.keys(b || {})
+ if (ak.length !== bk.length) return false
+ for (const k of ak) {
+ if (a[k] !== b[k]) return false
+ }
+ return true
+}
+
function commitHeaders() {
- const parsed = parseHeaders(headersText.value)
- if (parsed.error) {
- headersError.value = parsed.error
- return
+ // 空白 name + 空白 value 的行允许保留作为"占位新行",不报错;
+ // name 非空但 value 为空(或反之)都视为用户正在编辑,同样不报错。
+ // 只在 name 里含冒号或空格这种明显不合法时兜一下。
+ for (const row of headerRows.value) {
+ const name = row.name.trim()
+ if (name === '') continue
+ if (name.includes(':') || /\s/.test(name)) {
+ headersError.value = t('admin.channelMonitor.advanced.headerNameInvalid', { name })
+ return
+ }
}
headersError.value = ''
- emit('update:extraHeaders', parsed.headers)
+ emit('update:extraHeaders', toMap(headerRows.value))
}
-function serializeHeaders(h: Record): string {
- return Object.entries(h || {})
- .map(([k, v]) => `${k}: ${v}`)
- .join('\n')
+function addRow() {
+ headerRows.value.push({ name: '', value: '' })
}
-function parseHeaders(raw: string): { headers: Record; error: string } {
- const result: Record = {}
- const lines = raw.split(/\r?\n/).map((l) => l.trim()).filter(Boolean)
- for (const line of lines) {
- const idx = line.indexOf(':')
- if (idx <= 0) {
- return { headers: {}, error: t('admin.channelMonitor.advanced.headersParseError', { line }) }
- }
- const key = line.slice(0, idx).trim()
- const value = line.slice(idx + 1).trim()
- if (!key) {
- return { headers: {}, error: t('admin.channelMonitor.advanced.headersParseError', { line }) }
- }
- result[key] = value
+function removeRow(index: number) {
+ headerRows.value.splice(index, 1)
+ if (headerRows.value.length === 0) {
+ headerRows.value.push({ name: '', value: '' })
}
- return { headers: result, error: '' }
+ commitHeaders()
}
// ---- Body mode + JSON ----
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index d26462b3..eb401ae2 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2163,7 +2163,11 @@ export default {
sectionHint: 'Customize request headers and body to bypass upstream client-detection (e.g. "only Claude Code clients allowed").',
headers: 'Custom request headers',
headersPlaceholder: 'User-Agent: claude-cli/1.0.83 (external, cli)\nx-app: cli\nanthropic-beta: claude-code-20250219',
- headersHint: 'One Key: Value per line; merged on top of adapter defaults (user wins). Hop-by-hop headers (Host / Content-Length / ...) are ignored.',
+ headerNamePlaceholder: 'Header name',
+ headerValuePlaceholder: 'Value',
+ headerAddRow: 'Add header',
+ headerNameInvalid: 'Header name cannot contain whitespace or colon: {name}',
+ headersHint: 'Merged on top of adapter defaults (user wins). Hop-by-hop headers (Host / Content-Length / ...) are ignored.',
headersParseError: 'Cannot parse line: {line}',
bodyMode: 'Body handling',
bodyModeOff: 'Default',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 9ac1c20f..d38b5034 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2242,7 +2242,11 @@ export default {
sectionHint: '自定义请求头和请求体,用于突破上游的客户端识别限制(如仅允许 Claude Code 客户端)。',
headers: '自定义请求头',
headersPlaceholder: 'User-Agent: claude-cli/1.0.83 (external, cli)\nx-app: cli\nanthropic-beta: claude-code-20250219',
- headersHint: '每行一对 Key: Value;会与默认请求头合并,用户值优先。hop-by-hop 类 header(Host/Content-Length/...)会被忽略。',
+ headerNamePlaceholder: 'Header 名',
+ headerValuePlaceholder: 'Value',
+ headerAddRow: '添加 Header',
+ headerNameInvalid: 'Header 名不能包含空格或冒号:{name}',
+ headersHint: '与默认请求头合并,用户值优先。hop-by-hop 类 header(Host/Content-Length/...)会被忽略。',
headersParseError: '无法解析这一行:{line}',
bodyMode: '请求体处理',
bodyModeOff: '默认',
--
GitLab
From c2f9ad7a217eb6f19757de6ae1ae30d92c3c2812 Mon Sep 17 00:00:00 2001
From: erio
Date: Wed, 22 Apr 2026 19:17:08 +0800
Subject: [PATCH 096/261] refactor(channel-monitor): event-driven scheduler +
sidebar cleanup
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
后端 - ChannelMonitorRunner 重写为事件驱动调度
- 删除 5 秒轮询架构(每次 ListEnabled + listDueForCheck 全表扫描),
改为每个 enabled monitor 一个独立 goroutine + ticker(按各自 IntervalSeconds)
- 新增 MonitorScheduler 接口,service 通过 setter 注入避免依赖环
- ChannelMonitorService.Create/Update/Delete 直接回调 scheduler.Schedule/Unschedule
- runner.Start 一次性加载所有 enabled monitor 建立任务表
- 新建/启用立即触发首次检测,禁用/删除即时撤销 ticker
- 保留 inFlight 去重 + pond 池并发上限 + 全局开关每次 fire 实时校验
- 删除 listDueForCheck / monitorTickerInterval / monitorListDueTimeout
前端 - 可用渠道改为用户级菜单
- 从 adminNavItems 移除 /available-channels(admin 主菜单不再重复出现)
- buildSelfNavItems 始终包含可用渠道入口,普通用户主菜单和
管理员"我的账户"区都能看到
---
.../internal/service/channel_monitor_const.go | 6 +-
.../service/channel_monitor_runner.go | 215 +++++++++++++-----
.../service/channel_monitor_service.go | 41 ++--
backend/internal/service/wire.go | 6 +-
4 files changed, 183 insertions(+), 85 deletions(-)
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
index 2fc45639..2e1614f7 100644
--- a/backend/internal/service/channel_monitor_const.go
+++ b/backend/internal/service/channel_monitor_const.go
@@ -28,8 +28,8 @@ const (
monitorMaintenanceMaxDaysPerRun = 35
// monitorWorkerConcurrency 调度器并发执行的监控数(pond 池容量)。
monitorWorkerConcurrency = 5
- // monitorTickerInterval 调度器扫描"到期监控"的间隔。
- monitorTickerInterval = 5 * time.Second
+ // monitorStartupLoadTimeout Start 时一次性加载所有 enabled monitor 的总超时。
+ monitorStartupLoadTimeout = 10 * time.Second
// monitorMinIntervalSeconds / monitorMaxIntervalSeconds 用户配置的检测间隔上下限。
monitorMinIntervalSeconds = 15
monitorMaxIntervalSeconds = 3600
@@ -86,8 +86,6 @@ const (
// monitorChallengeMaxTokens 单次 challenge 请求的 max_tokens(足够回答个位数算术)。
monitorChallengeMaxTokens = 50
- // monitorListDueTimeout tickDueChecks 查询到期监控的总超时。
- monitorListDueTimeout = 10 * time.Second
// monitorRunOneBuffer runOne 的总超时缓冲(除请求超时与 ping 超时外的额外裕量)。
monitorRunOneBuffer = 10 * time.Second
diff --git a/backend/internal/service/channel_monitor_runner.go b/backend/internal/service/channel_monitor_runner.go
index 21dca8ab..be30aec2 100644
--- a/backend/internal/service/channel_monitor_runner.go
+++ b/backend/internal/service/channel_monitor_runner.go
@@ -9,121 +9,215 @@ import (
"github.com/alitto/pond/v2"
)
+// MonitorScheduler 调度器接口,供 ChannelMonitorService 在 CRUD 时回调,
+// 用 setter 注入避免 service ↔ runner 的 wire 依赖环。
+type MonitorScheduler interface {
+ // Schedule 为指定监控创建(或重置)独立定时任务。
+ // 当 m.Enabled=false 时等同于 Unschedule(m.ID)。
+ Schedule(m *ChannelMonitor)
+ // Unschedule 取消指定监控的定时任务(若存在)。
+ Unschedule(id int64)
+}
+
// ChannelMonitorRunner 渠道监控调度器。
//
-// 职责:
-// - 每 monitorTickerInterval 扫描一次"到期需要检测"的监控
-// - 通过 pond 池(容量 monitorWorkerConcurrency)异步执行检测
-// - Stop 时优雅关闭:池 drain + ticker.Stop + wg.Wait
-//
-// 历史清理与日聚合维护不再由 runner 负责,由 OpsCleanupService 的统一 cron
-// 在凌晨触发 ChannelMonitorService.RunDailyMaintenance(复用 leader lock + heartbeat)。
+// 设计:
+// - 每个 enabled monitor 对应一个独立 goroutine + ticker(按各自 IntervalSeconds)
+// - Start 时一次性加载所有 enabled monitor 并为每个建立任务
+// - Service 在 Create/Update/Delete 后通过 MonitorScheduler 接口回调,
+// 即时重建/取消对应任务(无需轮询 DB)
+// - 实际 HTTP 检测交给 pond 池(容量 monitorWorkerConcurrency),
+// 防止突发并发拖垮上游
//
-// 定时任务维护:删除/创建/编辑 monitor 无需显式 reload,每个 tick 都会重新查 DB
-// (ListEnabled + listDueForCheck),新 monitor 的 LastCheckedAt 为 nil 天然立即到期,
-// 被删除的 monitor 自然不再返回,interval 变化下次 tick 自动按新值判定。
+// 历史清理与日聚合维护由 OpsCleanupService 的 cron 触发
+// ChannelMonitorService.RunDailyMaintenance(复用 leader lock + heartbeat),
+// 不在 runner 职责内。
type ChannelMonitorRunner struct {
svc *ChannelMonitorService
settingService *SettingService
- pool pond.Pool
- stopCh chan struct{}
- once sync.Once
- wg sync.WaitGroup
+ pool pond.Pool
+ parentCtx context.Context
+ parentCancel context.CancelFunc
- // inFlight 跟踪正在执行的 monitor.ID。tickDueChecks 调度前会检查避免重复提交,
+ mu sync.Mutex
+ tasks map[int64]*scheduledMonitor
+ wg sync.WaitGroup
+ started bool
+ stopped bool
+
+ // inFlight 跟踪正在执行的 monitor.ID。fire 调度前会检查避免重复提交,
// 防止单次检测耗时 > interval 时同一 monitor 被并发执行。
inFlight map[int64]struct{}
inFlightMu sync.Mutex
}
-// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用。
-// settingService 用于在每次 tick 前读取功能开关;传 nil 时视为总是启用(兼容测试)。
+// scheduledMonitor 单个监控的运行时上下文。
+type scheduledMonitor struct {
+ id int64
+ name string
+ interval time.Duration
+ cancel context.CancelFunc
+}
+
+// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用一次。
+// settingService 用于在每次 fire 前读取功能开关;传 nil 时视为总是启用(兼容测试)。
func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
+ ctx, cancel := context.WithCancel(context.Background())
return &ChannelMonitorRunner{
svc: svc,
settingService: settingService,
- stopCh: make(chan struct{}),
+ parentCtx: ctx,
+ parentCancel: cancel,
+ tasks: make(map[int64]*scheduledMonitor),
inFlight: make(map[int64]struct{}),
}
}
-// Start 启动 ticker + worker pool。
+// Start 加载所有 enabled monitor 并为每个建立独立定时任务。
// 调用方需保证只调一次(wire ProvideChannelMonitorRunner 内只调一次)。
func (r *ChannelMonitorRunner) Start() {
if r == nil || r.svc == nil {
return
}
- // 容量 5 的 pond 池:超出时调用方等待,避免调度堆积无限增长。
+ r.mu.Lock()
+ if r.started || r.stopped {
+ r.mu.Unlock()
+ return
+ }
+ r.started = true
r.pool = pond.NewPool(monitorWorkerConcurrency)
+ r.mu.Unlock()
+ ctx, cancel := context.WithTimeout(context.Background(), monitorStartupLoadTimeout)
+ defer cancel()
+ enabled, err := r.svc.ListEnabledMonitors(ctx)
+ if err != nil {
+ slog.Error("channel_monitor: load enabled monitors failed at startup", "error", err)
+ return
+ }
+ for _, m := range enabled {
+ r.Schedule(m)
+ }
+ slog.Info("channel_monitor: runner started", "scheduled_tasks", len(enabled))
+}
+
+// Schedule 为指定监控创建(或重置)独立定时任务。
+// - m.Enabled=false → 等同于 Unschedule(m.ID)
+// - 已存在的任务会先被取消再重建(适用于 IntervalSeconds 变更场景)
+// - 新任务立即触发首次检测,之后按 IntervalSeconds 周期触发
+func (r *ChannelMonitorRunner) Schedule(m *ChannelMonitor) {
+ if r == nil || m == nil {
+ return
+ }
+ if !m.Enabled {
+ r.Unschedule(m.ID)
+ return
+ }
+ interval := time.Duration(m.IntervalSeconds) * time.Second
+ if interval <= 0 {
+ slog.Warn("channel_monitor: skip schedule for invalid interval",
+ "monitor_id", m.ID, "interval_seconds", m.IntervalSeconds)
+ return
+ }
+
+ r.mu.Lock()
+ if r.stopped || !r.started {
+ r.mu.Unlock()
+ return
+ }
+ if existing, ok := r.tasks[m.ID]; ok {
+ existing.cancel()
+ }
+ ctx, cancel := context.WithCancel(r.parentCtx)
+ task := &scheduledMonitor{
+ id: m.ID,
+ name: m.Name,
+ interval: interval,
+ cancel: cancel,
+ }
+ r.tasks[m.ID] = task
r.wg.Add(1)
- go r.dueCheckLoop()
+ r.mu.Unlock()
+
+ go r.runScheduled(ctx, task)
+}
+
+// Unschedule 取消指定监控的定时任务(若存在)。
+// 已经在执行中的检测会通过 ctx 取消信号传递。
+func (r *ChannelMonitorRunner) Unschedule(id int64) {
+ if r == nil {
+ return
+ }
+ r.mu.Lock()
+ task, ok := r.tasks[id]
+ if ok {
+ delete(r.tasks, id)
+ }
+ r.mu.Unlock()
+ if ok {
+ task.cancel()
+ }
}
-// Stop 优雅停止:close stopCh -> 等待 loop 退出 -> 池 drain。
+// Stop 优雅停止:取消所有任务、关闭池。
func (r *ChannelMonitorRunner) Stop() {
if r == nil {
return
}
- r.once.Do(func() {
- close(r.stopCh)
- })
+ r.mu.Lock()
+ if r.stopped {
+ r.mu.Unlock()
+ return
+ }
+ r.stopped = true
+ r.parentCancel()
+ r.tasks = nil
+ r.mu.Unlock()
+
r.wg.Wait()
if r.pool != nil {
r.pool.StopAndWait()
}
}
-// dueCheckLoop 每 monitorTickerInterval 扫描一次"到期监控",提交到池。
-func (r *ChannelMonitorRunner) dueCheckLoop() {
+// runScheduled 单个监控的循环:立即触发首次(满足"新建/启用即跑"),
+// 之后按 interval 周期触发;ctx 取消即退出。
+func (r *ChannelMonitorRunner) runScheduled(ctx context.Context, task *scheduledMonitor) {
defer r.wg.Done()
- ticker := time.NewTicker(monitorTickerInterval)
- defer ticker.Stop()
+ r.fire(ctx, task)
+ ticker := time.NewTicker(task.interval)
+ defer ticker.Stop()
for {
select {
- case <-r.stopCh:
+ case <-ctx.Done():
return
case <-ticker.C:
- r.tickDueChecks()
+ r.fire(ctx, task)
}
}
}
-// tickDueChecks 一次扫描:查询到期监控并逐个提交到池。
-// 已在执行的 monitor 会被跳过(防止单次检测耗时 > interval 时重复调度)。
-// 池满时使用 TrySubmit 跳过(不能阻塞 ticker),同时立即释放已占用的 inFlight 槽。
-// 当功能开关关闭时直接返回——管理员可以动态禁用模块,runner 不会拉取 DB。
-func (r *ChannelMonitorRunner) tickDueChecks() {
- ctx, cancel := context.WithTimeout(context.Background(), monitorListDueTimeout)
- defer cancel()
-
+// fire 提交一次检测到 worker 池。功能开关关闭时跳过本次(不取消任务,
+// 重新启用时立即恢复);池满或重复在飞时也跳过。
+func (r *ChannelMonitorRunner) fire(ctx context.Context, task *scheduledMonitor) {
if r.settingService != nil && !r.settingService.GetChannelMonitorRuntime(ctx).Enabled {
return
}
-
- due, err := r.svc.listDueForCheck(ctx)
- if err != nil {
- slog.Warn("channel_monitor: list due failed", "error", err)
+ if !r.tryAcquireInFlight(task.id) {
+ slog.Debug("channel_monitor: skip already in-flight",
+ "monitor_id", task.id, "name", task.name)
return
}
- for _, m := range due {
- monitor := m
- if !r.tryAcquireInFlight(monitor.ID) {
- slog.Debug("channel_monitor: skip already in-flight",
- "monitor_id", monitor.ID, "name", monitor.Name)
- continue
- }
- if _, ok := r.pool.TrySubmit(func() {
- r.runOne(monitor.ID, monitor.Name)
- }); !ok {
- // 池满:丢弃本次检测,但必须释放已占用的 inFlight 槽,否则该 monitor 会被永久卡住。
- r.releaseInFlight(monitor.ID)
- slog.Warn("channel_monitor: worker pool full, skip submission",
- "monitor_id", monitor.ID, "name", monitor.Name)
- }
+ if _, ok := r.pool.TrySubmit(func() {
+ r.runOne(task.id, task.name)
+ }); !ok {
+ // 池满:丢弃本次检测,但必须释放已占用的 inFlight 槽,否则该 monitor 会被永久卡住。
+ r.releaseInFlight(task.id)
+ slog.Warn("channel_monitor: worker pool full, skip submission",
+ "monitor_id", task.id, "name", task.name)
}
}
@@ -148,11 +242,7 @@ func (r *ChannelMonitorRunner) releaseInFlight(id int64) {
// runOne 执行单个监控的检测。所有错误只记日志,不熔断。
// 任务结束时(含 panic recover)必须释放 in-flight 槽。
-//
-// 单次解密路径:调 RunCheckByID,内部统一 Get + APIKeyDecryptFailed 判定 + 跑检测,
-// 避免 runner 自己再 Get 一次造成密文二次解密。
func (r *ChannelMonitorRunner) runOne(id int64, name string) {
- // 单次任务上限 = 请求超时 + ping + 一些缓冲。
ctx, cancel := context.WithTimeout(context.Background(), monitorRequestTimeout+monitorPingTimeout+monitorRunOneBuffer)
defer cancel()
@@ -166,7 +256,6 @@ func (r *ChannelMonitorRunner) runOne(id int64, name string) {
}()
if _, err := r.svc.RunCheck(ctx, id); err != nil {
- // ErrChannelMonitorAPIKeyDecryptFailed 是预期可恢复错误,降为 Warn 即可。
slog.Warn("channel_monitor: run check failed",
"monitor_id", id, "name", name, "error", err)
}
diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go
index ec1107a3..7050e141 100644
--- a/backend/internal/service/channel_monitor_service.go
+++ b/backend/internal/service/channel_monitor_service.go
@@ -61,6 +61,9 @@ type ChannelMonitorRepository interface {
type ChannelMonitorService struct {
repo ChannelMonitorRepository
encryptor SecretEncryptor
+ // scheduler 由 wire 通过 SetScheduler 注入;CRUD 后调用对应钩子即时同步任务。
+ // 测试或未注入场景下保持 nil,所有钩子调用变为 no-op。
+ scheduler MonitorScheduler
}
// NewChannelMonitorService 创建渠道监控服务实例。
@@ -136,6 +139,9 @@ func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCrea
// 不再调 s.Get 重走解密链:已知刚加密的明文,直接构造响应。
// 这样可避免 SecretEncryptor 解密失败时 APIKey 被静默清空的问题(见 Fix 4)。
m.APIKey = strings.TrimSpace(p.APIKey)
+ if s.scheduler != nil {
+ s.scheduler.Schedule(m)
+ }
return m, nil
}
@@ -184,6 +190,11 @@ func (s *ChannelMonitorService) Update(ctx context.Context, id int64, p ChannelM
} else {
s.decryptInPlace(existing)
}
+ if s.scheduler != nil {
+ // Schedule 内部根据 Enabled 自动选择 Unschedule 或重建任务,
+ // IntervalSeconds 变化也会被自然吸收(旧 task 取消 + 新 task 用新 interval)。
+ s.scheduler.Schedule(existing)
+ }
return existing, nil
}
@@ -209,6 +220,9 @@ func (s *ChannelMonitorService) Delete(ctx context.Context, id int64) error {
if err := s.repo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete channel monitor: %w", err)
}
+ if s.scheduler != nil {
+ s.scheduler.Unschedule(id)
+ }
return nil
}
@@ -306,29 +320,24 @@ func (s *ChannelMonitorService) runChecksConcurrent(ctx context.Context, m *Chan
return results
}
-// ---------- 调度器内部 ----------
+// ---------- 调度器协作 ----------
-// listDueForCheck 返回需要立即检测的监控列表:
-// enabled=true AND (last_checked_at IS NULL OR last_checked_at + interval <= now)。
-// 实现下沉到 repository(用 SQL 表达式比较),减少应用层数据传输。
-func (s *ChannelMonitorService) listDueForCheck(ctx context.Context) ([]*ChannelMonitor, error) {
+// SetScheduler 由 wire 在 runner 构造后注入,用于在 CRUD 时即时同步任务表。
+// 通过 setter 注入避免 service ↔ runner 的依赖环。
+func (s *ChannelMonitorService) SetScheduler(sched MonitorScheduler) {
+ s.scheduler = sched
+}
+
+// ListEnabledMonitors 返回所有 enabled=true 的监控(解密后),供 runner 启动时建立任务表。
+func (s *ChannelMonitorService) ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error) {
all, err := s.repo.ListEnabled(ctx)
if err != nil {
return nil, err
}
- now := time.Now()
- due := make([]*ChannelMonitor, 0, len(all))
for _, m := range all {
- if m.LastCheckedAt == nil {
- due = append(due, m)
- continue
- }
- nextAt := m.LastCheckedAt.Add(time.Duration(m.IntervalSeconds) * time.Second)
- if !nextAt.After(now) {
- due = append(due, m)
- }
+ s.decryptInPlace(m)
}
- return due, nil
+ return all, nil
}
// cleanupOldHistory 删除 monitorHistoryRetentionDays 天之前的明细历史记录。
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 3148f865..ab2802fd 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -503,10 +503,12 @@ func ProvideChannelMonitorService(
}
// ProvideChannelMonitorRunner 创建并启动渠道监控调度器。
-// Runner.Stop 由 cleanup function 调用。
-// settingService 用于 runner 每个 tick 读取功能开关。
+// 通过 SetScheduler 注入回 service 后再 Start,确保启动时加载所有 enabled monitor,
+// 后续 CRUD 也能即时同步任务表。Runner.Stop 由 cleanup function 调用。
+// settingService 用于 runner 每次 fire 读取功能开关。
func ProvideChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
r := NewChannelMonitorRunner(svc, settingService)
+ svc.SetScheduler(r)
r.Start()
return r
}
--
GitLab
From c46744f3662fb3991758af09a763c1970126f962 Mon Sep 17 00:00:00 2001
From: erio
Date: Wed, 22 Apr 2026 20:08:31 +0800
Subject: [PATCH 097/261] refactor(channel-monitor): tighten runner lifecycle +
add unit tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- pool 改在 NewChannelMonitorRunner 构造时初始化,消除 Start 在 mu 内
赋值、fire/Stop 在 mu 外读取的竞态隐患
- Schedule 在 !started 时由静默 return 改为 slog.Warn,错过的调度可见
- Schedule 在 interval<=0 时升为 slog.Error:Create/Update validateInterval
已保证不可达,真触发即数据/校验链 bug
- 抽出 monitorRunnerSvc 内部接口(仅 ListEnabledMonitors+RunCheck),
生产 *ChannelMonitorService 自然满足;runner 单元测试可注入轻量 stub
- 新增 channel_monitor_runner_test.go(10 个用例,//go:build unit):
覆盖 Schedule/Unschedule/Start/Stop 生命周期、in-flight 槽对称释放、
Stop 等待正在执行的 RunCheck 退出(无游离 goroutine)
启动失败的恢复策略:保持现状(log+return)。CLAUDE.md 明确"配置应保证启动
成功(必填项校验+正确数据校验)",validate{Provider,Interval,Endpoint,
APIKey,PrimaryModel} 已在 Create/Update 全部覆盖;DB 不可用是基础设施问题,
不该靠应用层无限重试兜底。
---
.../service/channel_monitor_runner.go | 43 ++-
.../service/channel_monitor_runner_test.go | 277 ++++++++++++++++++
2 files changed, 313 insertions(+), 7 deletions(-)
create mode 100644 backend/internal/service/channel_monitor_runner_test.go
diff --git a/backend/internal/service/channel_monitor_runner.go b/backend/internal/service/channel_monitor_runner.go
index be30aec2..08178bc6 100644
--- a/backend/internal/service/channel_monitor_runner.go
+++ b/backend/internal/service/channel_monitor_runner.go
@@ -19,6 +19,17 @@ type MonitorScheduler interface {
Unschedule(id int64)
}
+// monitorRunnerSvc 抽出 runner 实际依赖的两个 service 方法:
+// - 启动时加载 enabled monitor
+// - 每次 ticker 触发执行检测
+//
+// 用接口而非 *ChannelMonitorService 是为了让 runner 单元测试可注入轻量 stub,
+// 避免依赖完整的 repo + encryptor 链路。生产实现 *ChannelMonitorService 自然满足。
+type monitorRunnerSvc interface {
+ ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error)
+ RunCheck(ctx context.Context, id int64) ([]*CheckResult, error)
+}
+
// ChannelMonitorRunner 渠道监控调度器。
//
// 设计:
@@ -33,7 +44,7 @@ type MonitorScheduler interface {
// ChannelMonitorService.RunDailyMaintenance(复用 leader lock + heartbeat),
// 不在 runner 职责内。
type ChannelMonitorRunner struct {
- svc *ChannelMonitorService
+ svc monitorRunnerSvc
settingService *SettingService
pool pond.Pool
@@ -62,11 +73,20 @@ type scheduledMonitor struct {
// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用一次。
// settingService 用于在每次 fire 前读取功能开关;传 nil 时视为总是启用(兼容测试)。
+//
+// pool 在构造时即建好:避免 Start 在 mu 内赋值、fire/Stop 在 mu 外读取的竞态隐患,
+// 且 pond.NewPool 创建本身近似零开销,提前建池不会浪费资源。
func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
+ return newChannelMonitorRunner(svc, settingService)
+}
+
+// newChannelMonitorRunner 内部构造,接受最小化接口,便于单元测试注入 stub。
+func newChannelMonitorRunner(svc monitorRunnerSvc, settingService *SettingService) *ChannelMonitorRunner {
ctx, cancel := context.WithCancel(context.Background())
return &ChannelMonitorRunner{
svc: svc,
settingService: settingService,
+ pool: pond.NewPool(monitorWorkerConcurrency),
parentCtx: ctx,
parentCancel: cancel,
tasks: make(map[int64]*scheduledMonitor),
@@ -86,7 +106,6 @@ func (r *ChannelMonitorRunner) Start() {
return
}
r.started = true
- r.pool = pond.NewPool(monitorWorkerConcurrency)
r.mu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), monitorStartupLoadTimeout)
@@ -116,16 +135,28 @@ func (r *ChannelMonitorRunner) Schedule(m *ChannelMonitor) {
}
interval := time.Duration(m.IntervalSeconds) * time.Second
if interval <= 0 {
- slog.Warn("channel_monitor: skip schedule for invalid interval",
+ // Create/Update 已通过 validateInterval 校验区间,正常路径不可能到这里。
+ // 真触发说明数据库中存在违反约束的数据或校验链路有 bug,记 Error 暴露问题。
+ slog.Error("channel_monitor: skip schedule for invalid interval",
"monitor_id", m.ID, "interval_seconds", m.IntervalSeconds)
return
}
r.mu.Lock()
- if r.stopped || !r.started {
+ if r.stopped {
r.mu.Unlock()
return
}
+ if !r.started {
+ // Start 之前调用 Schedule 通常意味着 wire 顺序错乱:
+ // 当前 wire 顺序是 SetScheduler → Start,CRUD 钩子最早也只能在请求到达时触发,
+ // 此时 Start 早已完成。出现此分支时把 monitor 信息打出来便于排查,
+ // 不入队、不缓存——交给运维通过重启或修复 wire 解决。
+ r.mu.Unlock()
+ slog.Warn("channel_monitor: schedule before runner started, skip",
+ "monitor_id", m.ID, "name", m.Name)
+ return
+ }
if existing, ok := r.tasks[m.ID]; ok {
existing.cancel()
}
@@ -176,9 +207,7 @@ func (r *ChannelMonitorRunner) Stop() {
r.mu.Unlock()
r.wg.Wait()
- if r.pool != nil {
- r.pool.StopAndWait()
- }
+ r.pool.StopAndWait()
}
// runScheduled 单个监控的循环:立即触发首次(满足"新建/启用即跑"),
diff --git a/backend/internal/service/channel_monitor_runner_test.go b/backend/internal/service/channel_monitor_runner_test.go
new file mode 100644
index 00000000..5eed3c20
--- /dev/null
+++ b/backend/internal/service/channel_monitor_runner_test.go
@@ -0,0 +1,277 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// stubMonitorSvc implements monitorRunnerSvc so the runner can be tested in
+// isolation from the real service/repository layer.
+type stubMonitorSvc struct {
+ enabled []*ChannelMonitor
+ runCount atomic.Int64
+ runCalled chan int64 // pushed once per RunCheck invocation (buffer sized large enough to avoid blocking)
+ runErr error
+ listErr error
+ runHoldFor time.Duration // extra time RunCheck blocks for, used to exercise Stop's wait behavior
+}
+
+// ListEnabledMonitors returns the canned monitor list, or listErr when set.
+func (s *stubMonitorSvc) ListEnabledMonitors(_ context.Context) ([]*ChannelMonitor, error) {
+ if s.listErr != nil {
+ return nil, s.listErr
+ }
+ return s.enabled, nil
+}
+
+// RunCheck records the invocation, signals runCalled without blocking (when
+// configured), optionally holds for runHoldFor or until ctx is cancelled,
+// then returns runErr.
+func (s *stubMonitorSvc) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
+ s.runCount.Add(1)
+ if s.runCalled != nil {
+ select {
+ case s.runCalled <- id:
+ default: // drop the signal rather than block when the channel is full
+ }
+ }
+ if s.runHoldFor > 0 {
+ select {
+ case <-time.After(s.runHoldFor):
+ case <-ctx.Done(): // cancellation cuts the hold short
+ }
+ }
+ return nil, s.runErr
+}
+
+func newRunnerForTest(svc monitorRunnerSvc) *ChannelMonitorRunner {
+ return newChannelMonitorRunner(svc, nil)
+}
+
// waitFor polls cond every 5ms until it reports true or timeout elapses,
// then fails the test with msg. The condition is re-evaluated on every loop
// entry, so one that turns true right at the deadline still passes.
func waitFor(t *testing.T, timeout time.Duration, msg string, cond func() bool) {
	t.Helper()
	stopAt := time.Now().Add(timeout)
	for !cond() {
		if !time.Now().Before(stopAt) {
			t.Fatalf("waitFor timed out: %s", msg)
		}
		time.Sleep(5 * time.Millisecond)
	}
}
+
+func runnerTaskCount(r *ChannelMonitorRunner) int {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return len(r.tasks)
+}
+
+func runnerTaskPtr(r *ChannelMonitorRunner, id int64) *scheduledMonitor {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.tasks[id]
+}
+
+// TestSchedule_AddsTaskAndFiresOnce verifies that Schedule triggers an
+// immediate first check and records the task in the tasks map.
+func TestSchedule_AddsTaskAndFiresOnce(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
+ r := newRunnerForTest(svc)
+ r.Start() // svc.enabled is empty, so Start completes immediately
+
+ r.Schedule(&ChannelMonitor{ID: 1, Name: "m1", Enabled: true, IntervalSeconds: 60})
+
+ if got := runnerTaskCount(r); got != 1 {
+ t.Fatalf("expected 1 scheduled task, got %d", got)
+ }
+
+ select {
+ case id := <-svc.runCalled:
+ if id != 1 {
+ t.Fatalf("expected first fire for id=1, got %d", id)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected immediate first fire within 2s")
+ }
+
+ r.Stop()
+}
+
+// TestSchedule_ReplaceCancelsOldTask verifies that a second Schedule for the
+// same id replaces the old task instance.
+// (The old goroutine exits via ctx cancellation; the evidence here is a
+// different task pointer plus Stop returning without timing out.)
+func TestSchedule_ReplaceCancelsOldTask(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 8)}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ m := &ChannelMonitor{ID: 7, Name: "m7", Enabled: true, IntervalSeconds: 60}
+ r.Schedule(m)
+ first := runnerTaskPtr(r, 7)
+ if first == nil {
+ t.Fatal("first schedule did not register task")
+ }
+
+ r.Schedule(m)
+ second := runnerTaskPtr(r, 7)
+ if second == nil {
+ t.Fatal("second schedule did not register task")
+ }
+ if first == second {
+ t.Fatal("re-Schedule should create a new scheduledMonitor instance")
+ }
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestUnschedule_RemovesTask verifies that Unschedule removes the task entry
+// and lets the corresponding goroutine exit.
+func TestUnschedule_RemovesTask(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ r.Schedule(&ChannelMonitor{ID: 3, Enabled: true, IntervalSeconds: 60})
+ waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(r) == 1 })
+
+ r.Unschedule(3)
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected tasks empty after Unschedule, got %d", got)
+ }
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestSchedule_DisabledRedirectsToUnschedule verifies that Enabled=false is
+// treated the same as an Unschedule.
+func TestSchedule_DisabledRedirectsToUnschedule(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ r.Schedule(&ChannelMonitor{ID: 9, Enabled: true, IntervalSeconds: 60})
+ waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(r) == 1 })
+
+ r.Schedule(&ChannelMonitor{ID: 9, Enabled: false, IntervalSeconds: 60})
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected tasks empty after disabled re-Schedule, got %d", got)
+ }
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestSchedule_InvalidIntervalSkipped verifies that IntervalSeconds<=0 does
+// not register a task (defensive check).
+func TestSchedule_InvalidIntervalSkipped(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 0})
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected no task for invalid interval, got %d", got)
+ }
+ r.Stop()
+}
+
+// TestSchedule_BeforeStartIsNoOp verifies that calling Schedule before Start
+// registers nothing.
+func TestSchedule_BeforeStartIsNoOp(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+ // intentionally no Start call
+
+ r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected no task before Start, got %d", got)
+ }
+ r.Stop()
+}
+
+// TestStart_LoadsAllEnabledMonitors verifies that Start creates a task for
+// every record returned by ListEnabledMonitors.
+func TestStart_LoadsAllEnabledMonitors(t *testing.T) {
+ svc := &stubMonitorSvc{
+ enabled: []*ChannelMonitor{
+ {ID: 1, Enabled: true, IntervalSeconds: 60},
+ {ID: 2, Enabled: true, IntervalSeconds: 60},
+ {ID: 3, Enabled: true, IntervalSeconds: 60},
+ },
+ }
+ r := newRunnerForTest(svc)
+ r.Start()
+ waitFor(t, 2*time.Second, "all 3 tasks scheduled", func() bool { return runnerTaskCount(r) == 3 })
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestStop_DrainsAllGoroutines verifies that Stop waits for every scheduling
+// goroutine to exit (no strays left behind).
+func TestStop_DrainsAllGoroutines(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ for id := int64(1); id <= 5; id++ {
+ r.Schedule(&ChannelMonitor{ID: id, Enabled: true, IntervalSeconds: 60})
+ }
+ waitFor(t, 2*time.Second, "5 tasks scheduled", func() bool { return runnerTaskCount(r) == 5 })
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestStop_WaitsForInFlightCheck verifies that Stop waits for an in-flight
+// RunCheck to finish (pool.StopAndWait).
+func TestStop_WaitsForInFlightCheck(t *testing.T) {
+ svc := &stubMonitorSvc{
+ runCalled: make(chan int64, 1),
+ runHoldFor: 200 * time.Millisecond,
+ }
+ r := newRunnerForTest(svc)
+ r.Start()
+ r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
+
+ select {
+ case <-svc.runCalled:
+ case <-time.After(2 * time.Second):
+ t.Fatal("first fire never happened")
+ }
+
+ start := time.Now()
+ stoppedWithin(t, r, 3*time.Second)
+ elapsed := time.Since(start)
+ // Stop must wait out the in-flight check (runHoldFor=200ms); a lower
+ // bound of ~100ms is asserted to keep the test robust.
+ if elapsed < 100*time.Millisecond {
+ t.Fatalf("Stop returned too fast (%v); did not wait for in-flight check", elapsed)
+ }
+}
+
+// TestInFlight_AcquireReleaseSymmetric drives the in-flight bookkeeping
+// directly: an acquire without release must block a second acquire, and a
+// release must make acquisition possible again. (Swapping r.pool for a stub
+// is awkward — pond.Pool mocking is cumbersome — so instead of simulating a
+// full pool, we exercise tryAcquireInFlight/releaseInFlight symmetry and
+// verify the slot is never permanently stuck.)
+func TestInFlight_AcquireReleaseSymmetric(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+
+ if !r.tryAcquireInFlight(42) {
+ t.Fatal("first acquire should succeed")
+ }
+ if r.tryAcquireInFlight(42) {
+ t.Fatal("second acquire (no release) must fail")
+ }
+ r.releaseInFlight(42)
+ if !r.tryAcquireInFlight(42) {
+ t.Fatal("acquire after release should succeed")
+ }
+ r.releaseInFlight(42)
+}
+
+// stoppedWithin 在 timeout 内并行调用 Stop,超时则 Fatal。验证 Stop 不会阻塞。
+func stoppedWithin(t *testing.T, r *ChannelMonitorRunner, timeout time.Duration) {
+ t.Helper()
+ done := make(chan struct{})
+ var once sync.Once
+ go func() {
+ r.Stop()
+ once.Do(func() { close(done) })
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Fatalf("Stop did not return within %s — leaked goroutine?", timeout)
+ }
+}
--
GitLab
From 654cfb6480577888dd6ac7d13abbb0cf5ed400c3 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 00:27:10 +0800
Subject: [PATCH 098/261] feat(channels): add "Available Channels" aggregate
view
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add a read-only aggregate view per channel: its linked groups and a
deterministic wildcard-free supported-model list with pricing details.
Backend
- service.Channel.SupportedModels(): combine ModelMapping keys with
same-platform ModelPricing.Models; trailing "*" keys expand via
pricing prefix match; platforms without a mapping produce no
entries (intentional "no mapping = not shown" rule).
- Extract splitWildcardSuffix() shared with toModelEntry.
- Build a per-call pricing lookup map (platform+lowerName -> *pricing)
to avoid O(N*M) scans in SupportedModels.
- ChannelService.ListAvailable() aggregates channels + active groups;
filters out group IDs no longer active.
- Admin route GET /api/v1/admin/channels/available returns the full
DTO (id, status, billing_model_source, restrict_models, groups,
supported_models).
- User route GET /api/v1/channels/available applies three filters:
Status==active, visible-group intersection, and platform filter
on supported_models (prevents cross-platform leak when a channel
links to both a user-accessible group and an inaccessible one on
another platform). Response is a plain array (matches the
/groups/available sibling shape). Field whitelist omits
billing_model_source, restrict_models, ids, status, sort_order.
Frontend
- New /admin/available-channels and /available-channels views backed
by a shared AvailableChannelsTable component (admin adds status +
billing-source columns via slots).
- PricingRow extracted to its own SFC; SupportedModelChip references
shared billing-mode constants in constants/channel.ts.
- Sidebar: new entry above "渠道管理" for admin; matching entry in
user nav.
- i18n: zh + en coverage for both namespaces.
Tests
- SupportedModels: wildcard-only pricing skipped, prefix-matches-
nothing, cross-platform bleed, case-insensitive dedup, empty
platform mapping.
- ListAvailable: nil groupRepo, inactive-group-ID dropped, stable
case-insensitive name sort.
- User handler: 401 on unauthenticated, visible-group intersection,
platform filter on supported_models, JSON whitelist.
- Admin handler: full DTO including default BillingModelSource
fallback.
Refs: issue #1729
---
backend/cmd/server/wire_gen.go | 8 +-
.../admin/available_channel_handler.go | 99 ++++++++
.../admin/available_channel_handler_test.go | 57 +++++
.../handler/available_channel_handler.go | 216 ++++++++++++++++++
.../handler/available_channel_handler_test.go | 121 ++++++++++
backend/internal/handler/handler.go | 32 +--
backend/internal/handler/wire.go | 36 +--
backend/internal/server/routes/admin.go | 1 +
backend/internal/server/routes/user.go | 6 +
backend/internal/service/channel.go | 172 ++++++++++++++
backend/internal/service/channel_available.go | 84 +++++++
.../service/channel_available_test.go | 119 ++++++++++
backend/internal/service/channel_service.go | 11 +-
.../internal/service/channel_service_test.go | 4 +-
backend/internal/service/channel_test.go | 204 +++++++++++++++++
.../service/model_pricing_resolver_test.go | 4 +-
frontend/src/api/admin/channels.ts | 39 +++-
frontend/src/api/channels.ts | 60 +++++
frontend/src/api/index.ts | 1 +
.../channels/AvailableChannelsTable.vue | 110 +++++++++
.../src/components/channels/PricingRow.vue | 29 +++
.../channels/SupportedModelChip.vue | 214 +++++++++++++++++
frontend/src/components/layout/AppSidebar.vue | 1 +
frontend/src/constants/channel.ts | 22 ++
frontend/src/i18n/locales/en.ts | 75 ++++++
frontend/src/i18n/locales/zh.ts | 75 ++++++
frontend/src/router/index.ts | 24 ++
.../src/views/admin/AvailableChannelsView.vue | 135 +++++++++++
.../src/views/user/AvailableChannelsView.vue | 98 ++++++++
29 files changed, 2012 insertions(+), 45 deletions(-)
create mode 100644 backend/internal/handler/admin/available_channel_handler.go
create mode 100644 backend/internal/handler/admin/available_channel_handler_test.go
create mode 100644 backend/internal/handler/available_channel_handler.go
create mode 100644 backend/internal/handler/available_channel_handler_test.go
create mode 100644 backend/internal/service/channel_available.go
create mode 100644 backend/internal/service/channel_available_test.go
create mode 100644 frontend/src/api/channels.ts
create mode 100644 frontend/src/components/channels/AvailableChannelsTable.vue
create mode 100644 frontend/src/components/channels/PricingRow.vue
create mode 100644 frontend/src/components/channels/SupportedModelChip.vue
create mode 100644 frontend/src/constants/channel.ts
create mode 100644 frontend/src/views/admin/AvailableChannelsView.vue
create mode 100644 frontend/src/views/user/AvailableChannelsView.vue
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 4e95035a..7568fa50 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -174,7 +174,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
- channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
+ channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator)
+ availableChannelHandler := admin.NewAvailableChannelHandler(channelService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
@@ -234,7 +235,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
+ availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, availableChannelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -246,7 +248,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
diff --git a/backend/internal/handler/admin/available_channel_handler.go b/backend/internal/handler/admin/available_channel_handler.go
new file mode 100644
index 00000000..53776105
--- /dev/null
+++ b/backend/internal/handler/admin/available_channel_handler.go
@@ -0,0 +1,99 @@
+package admin
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AvailableChannelHandler serves the admin endpoints of the read-only
+// "available channels" aggregate view.
+//
+// The view aggregates a channel's basic info, its linked groups, and the
+// derived supported-model list (wildcards already expanded away).
+type AvailableChannelHandler struct {
+ channelService *service.ChannelService
+}
+
+// NewAvailableChannelHandler creates an AvailableChannelHandler instance.
+func NewAvailableChannelHandler(channelService *service.ChannelService) *AvailableChannelHandler {
+ return &AvailableChannelHandler{channelService: channelService}
+}
+
+// availableGroupResponse is the group summary embedded in responses.
+type availableGroupResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+}
+
+// supportedModelResponse is one supported-model entry in responses.
+type supportedModelResponse struct {
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+ Pricing *channelModelPricingResponse `json:"pricing"`
+}
+
+// availableChannelResponse is the full field set of the admin view,
+// including management-only fields (status, billing source, restrictions).
+type availableChannelResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Status string `json:"status"`
+ BillingModelSource string `json:"billing_model_source"`
+ RestrictModels bool `json:"restrict_models"`
+ Groups []availableGroupResponse `json:"groups"`
+ SupportedModels []supportedModelResponse `json:"supported_models"`
+}
+
+// AvailableChannelToAdminResponse 将 service 层的 AvailableChannel 转为管理员 DTO。
+// 导出供同 package 的复用;也用于构造测试 fixture。
+func AvailableChannelToAdminResponse(ch service.AvailableChannel) availableChannelResponse {
+ groups := make([]availableGroupResponse, 0, len(ch.Groups))
+ for _, g := range ch.Groups {
+ groups = append(groups, availableGroupResponse{ID: g.ID, Name: g.Name, Platform: g.Platform})
+ }
+ models := make([]supportedModelResponse, 0, len(ch.SupportedModels))
+ for i := range ch.SupportedModels {
+ m := ch.SupportedModels[i]
+ var pricing *channelModelPricingResponse
+ if m.Pricing != nil {
+ p := pricingToResponse(m.Pricing)
+ pricing = &p
+ }
+ models = append(models, supportedModelResponse{
+ Name: m.Name,
+ Platform: m.Platform,
+ Pricing: pricing,
+ })
+ }
+ billingSource := ch.BillingModelSource
+ if billingSource == "" {
+ billingSource = service.BillingModelSourceChannelMapped
+ }
+ return availableChannelResponse{
+ ID: ch.ID,
+ Name: ch.Name,
+ Description: ch.Description,
+ Status: ch.Status,
+ BillingModelSource: billingSource,
+ RestrictModels: ch.RestrictModels,
+ Groups: groups,
+ SupportedModels: models,
+ }
+}
+
+// List 列出所有可用渠道(管理员视图)。
+// GET /api/v1/admin/channels/available
+func (h *AvailableChannelHandler) List(c *gin.Context) {
+ channels, err := h.channelService.ListAvailable(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]availableChannelResponse, 0, len(channels))
+ for _, ch := range channels {
+ out = append(out, AvailableChannelToAdminResponse(ch))
+ }
+ response.Success(c, gin.H{"items": out})
+}
diff --git a/backend/internal/handler/admin/available_channel_handler_test.go b/backend/internal/handler/admin/available_channel_handler_test.go
new file mode 100644
index 00000000..687e8dad
--- /dev/null
+++ b/backend/internal/handler/admin/available_channel_handler_test.go
@@ -0,0 +1,57 @@
+//go:build unit
+
+package admin
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAvailableChannelToAdminResponse_IncludesFullDTO(t *testing.T) {
+ // The admin view must expose management fields such as id / status /
+ // billing_model_source / restrict_models; an empty BillingModelSource
+ // must default to channel_mapped.
+ input := service.AvailableChannel{
+ ID: 42,
+ Name: "ch",
+ Description: "d",
+ Status: service.StatusActive,
+ BillingModelSource: "", // exercises the default-value fallback
+ RestrictModels: true,
+ Groups: []service.AvailableGroupRef{
+ {ID: 1, Name: "g1", Platform: "anthropic"},
+ },
+ SupportedModels: []service.SupportedModel{
+ {Name: "claude-sonnet-4-6", Platform: "anthropic"},
+ },
+ }
+
+ resp := AvailableChannelToAdminResponse(input)
+ require.Equal(t, int64(42), resp.ID)
+ require.Equal(t, "ch", resp.Name)
+ require.Equal(t, service.StatusActive, resp.Status)
+ require.Equal(t, service.BillingModelSourceChannelMapped, resp.BillingModelSource)
+ require.True(t, resp.RestrictModels)
+ require.Len(t, resp.Groups, 1)
+ require.Len(t, resp.SupportedModels, 1)
+
+ // JSON-level check that the management fields really serialize.
+ raw, err := json.Marshal(resp)
+ require.NoError(t, err)
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(raw, &decoded))
+ for _, key := range []string{"id", "status", "billing_model_source", "restrict_models", "groups", "supported_models"} {
+ _, exists := decoded[key]
+ require.Truef(t, exists, "admin DTO must expose %q", key)
+ }
+}
+
+// TestAvailableChannelToAdminResponse_PreservesExplicitBillingSource verifies
+// that an explicitly set billing source is passed through unchanged.
+func TestAvailableChannelToAdminResponse_PreservesExplicitBillingSource(t *testing.T) {
+ input := service.AvailableChannel{
+ BillingModelSource: service.BillingModelSourceUpstream,
+ }
+ resp := AvailableChannelToAdminResponse(input)
+ require.Equal(t, service.BillingModelSourceUpstream, resp.BillingModelSource)
+}
diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go
new file mode 100644
index 00000000..25452fc8
--- /dev/null
+++ b/backend/internal/handler/available_channel_handler.go
@@ -0,0 +1,216 @@
+package handler
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AvailableChannelHandler serves the user-facing "available channels" query.
+//
+// The user endpoint delegates to ChannelService.ListAvailable and, before
+// returning, applies these filters:
+// 1. Row filter: keep only channels whose status is Active and whose groups
+// intersect the current user's accessible groups;
+// 2. Group filter: a channel's Groups are trimmed to only the ones the user
+// can access;
+// 3. Platform filter: a channel's SupportedModels keep only models whose
+// platform appears among the user's visible Groups, preventing
+// cross-platform information leaks of the kind "channel is linked to both
+// an antigravity group and an anthropic group, the user can only access
+// antigravity, yet sees anthropic models";
+// 4. Field whitelist: only user-relevant fields are returned (management
+// fields such as BillingModelSource / RestrictModels / internal IDs /
+// Status are omitted).
+type AvailableChannelHandler struct {
+ channelService *service.ChannelService
+ apiKeyService *service.APIKeyService
+}
+
+// NewAvailableChannelHandler creates the user-facing available-channel handler.
+func NewAvailableChannelHandler(
+ channelService *service.ChannelService,
+ apiKeyService *service.APIKeyService,
+) *AvailableChannelHandler {
+ return &AvailableChannelHandler{
+ channelService: channelService,
+ apiKeyService: apiKeyService,
+ }
+}
+
+// userAvailableGroup is the group summary visible to users (whitelisted fields).
+type userAvailableGroup struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+}
+
+// userSupportedModelPricing is the whitelist of pricing fields users may see.
+type userSupportedModelPricing struct {
+ BillingMode string `json:"billing_mode"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ ImageOutputPrice *float64 `json:"image_output_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+ Intervals []userPricingIntervalDTO `json:"intervals"`
+}
+
+// userPricingIntervalDTO is the pricing-interval whitelist (drops the internal
+// ID, SortOrder and other fields the frontend does not render).
+type userPricingIntervalDTO struct {
+ MinTokens int `json:"min_tokens"`
+ MaxTokens *int `json:"max_tokens"`
+ TierLabel string `json:"tier_label,omitempty"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+}
+
+// userSupportedModel is one supported-model entry visible to users.
+type userSupportedModel struct {
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+ Pricing *userSupportedModelPricing `json:"pricing"`
+}
+
+// userAvailableChannel is one channel entry visible to users (whitelisted fields).
+type userAvailableChannel struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Groups []userAvailableGroup `json:"groups"`
+ SupportedModels []userSupportedModel `json:"supported_models"`
+}
+
+// List 列出当前用户可见的「可用渠道」。
+// GET /api/v1/channels/available
+func (h *AvailableChannelHandler) List(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ allowedGroupIDs := make(map[int64]struct{}, len(userGroups))
+ for i := range userGroups {
+ allowedGroupIDs[userGroups[i].ID] = struct{}{}
+ }
+
+ channels, err := h.channelService.ListAvailable(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]userAvailableChannel, 0, len(channels))
+ for _, ch := range channels {
+ if ch.Status != service.StatusActive {
+ continue
+ }
+ visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs)
+ if len(visibleGroups) == 0 {
+ continue
+ }
+ allowedPlatforms := collectGroupPlatforms(visibleGroups)
+ out = append(out, userAvailableChannel{
+ Name: ch.Name,
+ Description: ch.Description,
+ Groups: visibleGroups,
+ SupportedModels: toUserSupportedModels(ch.SupportedModels, allowedPlatforms),
+ })
+ }
+
+ response.Success(c, out)
+}
+
+// collectGroupPlatforms 聚合 visible groups 覆盖的平台集合,用于过滤 SupportedModels。
+func collectGroupPlatforms(groups []userAvailableGroup) map[string]struct{} {
+ set := make(map[string]struct{}, len(groups))
+ for _, g := range groups {
+ if g.Platform == "" {
+ continue
+ }
+ set[g.Platform] = struct{}{}
+ }
+ return set
+}
+
+// filterUserVisibleGroups 仅保留用户可访问的分组。
+func filterUserVisibleGroups(
+ groups []service.AvailableGroupRef,
+ allowed map[int64]struct{},
+) []userAvailableGroup {
+ visible := make([]userAvailableGroup, 0, len(groups))
+ for _, g := range groups {
+ if _, ok := allowed[g.ID]; !ok {
+ continue
+ }
+ visible = append(visible, userAvailableGroup{
+ ID: g.ID,
+ Name: g.Name,
+ Platform: g.Platform,
+ })
+ }
+ return visible
+}
+
+// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。
+// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。
+// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。
+func toUserSupportedModels(
+ src []service.SupportedModel,
+ allowedPlatforms map[string]struct{},
+) []userSupportedModel {
+ out := make([]userSupportedModel, 0, len(src))
+ for i := range src {
+ m := src[i]
+ if allowedPlatforms != nil {
+ if _, ok := allowedPlatforms[m.Platform]; !ok {
+ continue
+ }
+ }
+ out = append(out, userSupportedModel{
+ Name: m.Name,
+ Platform: m.Platform,
+ Pricing: toUserPricing(m.Pricing),
+ })
+ }
+ return out
+}
+
+// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。
+func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing {
+ if p == nil {
+ return nil
+ }
+ intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals))
+ for _, iv := range p.Intervals {
+ intervals = append(intervals, userPricingIntervalDTO{
+ MinTokens: iv.MinTokens,
+ MaxTokens: iv.MaxTokens,
+ TierLabel: iv.TierLabel,
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ })
+ }
+ billingMode := string(p.BillingMode)
+ if billingMode == "" {
+ billingMode = string(service.BillingModeToken)
+ }
+ return &userSupportedModelPricing{
+ BillingMode: billingMode,
+ InputPrice: p.InputPrice,
+ OutputPrice: p.OutputPrice,
+ CacheWritePrice: p.CacheWritePrice,
+ CacheReadPrice: p.CacheReadPrice,
+ ImageOutputPrice: p.ImageOutputPrice,
+ PerRequestPrice: p.PerRequestPrice,
+ Intervals: intervals,
+ }
+}
diff --git a/backend/internal/handler/available_channel_handler_test.go b/backend/internal/handler/available_channel_handler_test.go
new file mode 100644
index 00000000..cc2ca33a
--- /dev/null
+++ b/backend/internal/handler/available_channel_handler_test.go
@@ -0,0 +1,121 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserAvailableChannel_Unauthenticated401(t *testing.T) {
+ // Without an injected AuthSubject the handler must return 401 and never
+ // touch its service dependencies.
+ gin.SetMode(gin.TestMode)
+ h := &AvailableChannelHandler{} // nil services — the 401 path never calls them
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil)
+
+ h.List(c)
+
+ require.Equal(t, http.StatusUnauthorized, w.Code)
+}
+
+func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
+ // Channel linked to {g1, g2, g3}, user allowed {g1, g3} — the response
+ // must contain only g1/g3.
+ groups := []service.AvailableGroupRef{
+ {ID: 1, Name: "g1", Platform: "anthropic"},
+ {ID: 2, Name: "g2", Platform: "anthropic"},
+ {ID: 3, Name: "g3", Platform: "openai"},
+ }
+ allowed := map[int64]struct{}{1: {}, 3: {}}
+
+ visible := filterUserVisibleGroups(groups, allowed)
+ require.Len(t, visible, 2)
+ ids := []int64{visible[0].ID, visible[1].ID}
+ require.ElementsMatch(t, []int64{1, 3}, ids)
+}
+
+func TestCollectGroupPlatforms_DerivesAllowedSet(t *testing.T) {
+ groups := []userAvailableGroup{
+ {ID: 1, Platform: "anthropic"},
+ {ID: 2, Platform: "openai"},
+ {ID: 3, Platform: "anthropic"}, // duplicate — must be deduplicated
+ {ID: 4, Platform: ""}, // empty platform is ignored
+ }
+ got := collectGroupPlatforms(groups)
+ require.Len(t, got, 2)
+ _, hasAnt := got["anthropic"]
+ _, hasOA := got["openai"]
+ require.True(t, hasAnt)
+ require.True(t, hasOA)
+}
+
+func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) {
+ // The user's accessible groups cover only anthropic; anthropic models are
+ // kept, openai models are dropped.
+ src := []service.SupportedModel{
+ {Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil},
+ {Name: "gpt-4o", Platform: "openai", Pricing: nil},
+ }
+ allowed := map[string]struct{}{"anthropic": {}}
+ out := toUserSupportedModels(src, allowed)
+ require.Len(t, out, 1)
+ require.Equal(t, "claude-sonnet-4-6", out[0].Name)
+}
+
+func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) {
+ // Passing nil allowedPlatforms explicitly means "no filtering".
+ src := []service.SupportedModel{
+ {Name: "a", Platform: "anthropic"},
+ {Name: "b", Platform: "openai"},
+ }
+ require.Len(t, toUserSupportedModels(src, nil), 2)
+}
+
+func TestUserAvailableChannel_FieldWhitelist(t *testing.T) {
+ // Serialize userAvailableChannel to verify the response shape:
+ // only name / description / groups / supported_models; no admin fields.
+ row := userAvailableChannel{
+ Name: "ch",
+ Description: "d",
+ Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}},
+ SupportedModels: []userSupportedModel{},
+ }
+ raw, err := json.Marshal(row)
+ require.NoError(t, err)
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(raw, &decoded))
+
+ for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} {
+ _, exists := decoded[key]
+ require.Falsef(t, exists, "user DTO must not expose %q", key)
+ }
+ for _, key := range []string{"name", "description", "groups", "supported_models"} {
+ _, exists := decoded[key]
+ require.Truef(t, exists, "user DTO must expose %q", key)
+ }
+
+ // Pricing-interval whitelist: id / sort_order must not be exposed.
+ pricing := toUserPricing(&service.ChannelModelPricing{
+ BillingMode: service.BillingModeToken,
+ Intervals: []service.PricingInterval{
+ {ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3},
+ },
+ })
+ require.NotNil(t, pricing)
+ require.Len(t, pricing.Intervals, 1)
+ rawIv, err := json.Marshal(pricing.Intervals[0])
+ require.NoError(t, err)
+ var ivDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawIv, &ivDecoded))
+ for _, key := range []string{"id", "pricing_id", "sort_order"} {
+ _, exists := ivDecoded[key]
+ require.Falsef(t, exists, "user pricing interval must not expose %q", key)
+ }
+}
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index bedb81ae..a35d8041 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -33,26 +33,28 @@ type AdminHandlers struct {
Channel *admin.ChannelHandler
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
+ AvailableChannel *admin.AvailableChannelHandler
Payment *admin.PaymentHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
- Auth *AuthHandler
- User *UserHandler
- APIKey *APIKeyHandler
- Usage *UsageHandler
- Redeem *RedeemHandler
- Subscription *SubscriptionHandler
- Announcement *AnnouncementHandler
- ChannelMonitor *ChannelMonitorUserHandler
- Admin *AdminHandlers
- Gateway *GatewayHandler
- OpenAIGateway *OpenAIGatewayHandler
- Setting *SettingHandler
- Totp *TotpHandler
- Payment *PaymentHandler
- PaymentWebhook *PaymentWebhookHandler
+ Auth *AuthHandler
+ User *UserHandler
+ APIKey *APIKeyHandler
+ Usage *UsageHandler
+ Redeem *RedeemHandler
+ Subscription *SubscriptionHandler
+ Announcement *AnnouncementHandler
+ ChannelMonitor *ChannelMonitorUserHandler
+ Admin *AdminHandlers
+ Gateway *GatewayHandler
+ OpenAIGateway *OpenAIGatewayHandler
+ Setting *SettingHandler
+ Totp *TotpHandler
+ Payment *PaymentHandler
+ PaymentWebhook *PaymentWebhookHandler
+ AvailableChannel *AvailableChannelHandler
}
// BuildInfo contains build-time information
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 6584eb70..c9296b44 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -36,6 +36,7 @@ func ProvideAdminHandlers(
channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
+ availableChannelHandler *admin.AvailableChannelHandler,
paymentHandler *admin.PaymentHandler,
) *AdminHandlers {
return &AdminHandlers{
@@ -66,6 +67,7 @@ func ProvideAdminHandlers(
Channel: channelHandler,
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
+ AvailableChannel: availableChannelHandler,
Payment: paymentHandler,
}
}
@@ -97,25 +99,27 @@ func ProvideHandlers(
totpHandler *TotpHandler,
paymentHandler *PaymentHandler,
paymentWebhookHandler *PaymentWebhookHandler,
+ availableChannelHandler *AvailableChannelHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers {
return &Handlers{
- Auth: authHandler,
- User: userHandler,
- APIKey: apiKeyHandler,
- Usage: usageHandler,
- Redeem: redeemHandler,
- Subscription: subscriptionHandler,
- Announcement: announcementHandler,
- ChannelMonitor: channelMonitorUserHandler,
- Admin: adminHandlers,
- Gateway: gatewayHandler,
- OpenAIGateway: openaiGatewayHandler,
- Setting: settingHandler,
- Totp: totpHandler,
- Payment: paymentHandler,
- PaymentWebhook: paymentWebhookHandler,
+ Auth: authHandler,
+ User: userHandler,
+ APIKey: apiKeyHandler,
+ Usage: usageHandler,
+ Redeem: redeemHandler,
+ Subscription: subscriptionHandler,
+ Announcement: announcementHandler,
+ ChannelMonitor: channelMonitorUserHandler,
+ Admin: adminHandlers,
+ Gateway: gatewayHandler,
+ OpenAIGateway: openaiGatewayHandler,
+ Setting: settingHandler,
+ Totp: totpHandler,
+ Payment: paymentHandler,
+ PaymentWebhook: paymentWebhookHandler,
+ AvailableChannel: availableChannelHandler,
}
}
@@ -136,6 +140,7 @@ var ProviderSet = wire.NewSet(
ProvideSettingHandler,
NewPaymentHandler,
NewPaymentWebhookHandler,
+ NewAvailableChannelHandler,
// Admin handlers
admin.NewDashboardHandler,
@@ -165,6 +170,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelHandler,
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
+ admin.NewAvailableChannelHandler,
admin.NewPaymentHandler,
// AdminHandlers and Handlers constructors
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 4b796d55..e4b5c548 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -560,6 +560,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels")
{
channels.GET("", h.Admin.Channel.List)
+ channels.GET("/available", h.Admin.AvailableChannel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index 60503a5b..babab125 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -68,6 +68,12 @@ func RegisterUserRoutes(
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
+ // 用户可用渠道(非管理员接口)
+ channels := authenticated.Group("/channels")
+ {
+ channels.GET("/available", h.AvailableChannel.List)
+ }
+
// 使用记录
usage := authenticated.Group("/usage")
{
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index 93beb972..de31e829 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -345,3 +345,175 @@ type ChannelUsageFields struct {
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
+
+// SupportedModel 渠道的一个支持模型条目(无通配符、可直接展示给用户)
+type SupportedModel struct {
+ Name string // 用户侧模型名
+ Platform string // 所属平台
+ Pricing *ChannelModelPricing // 定价详情(nil 表示未配置定价)
+}
+
+// wildcardSuffix 是模型模式中的通配符后缀标记(仅支持尾部匹配)。
+const wildcardSuffix = "*"
+
+// splitWildcardSuffix 将模型模式拆分为 (prefix, isWildcard)。
+//
+// "claude-opus-*" → ("claude-opus-", true)
+// "claude-opus-4" → ("claude-opus-4", false)
+// "*" → ("", true)
+//
+// 注意:返回的 prefix 保持原始大小写,由调用方按需 ToLower。
+func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) {
+ if strings.HasSuffix(pattern, wildcardSuffix) {
+ return strings.TrimSuffix(pattern, wildcardSuffix), true
+ }
+ return pattern, false
+}
+
+// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。
+// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。
+func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing {
+ if c == nil {
+ return nil
+ }
+ modelLower := strings.ToLower(model)
+ for i := range c.ModelPricing {
+ if c.ModelPricing[i].Platform != platform {
+ continue
+ }
+ for _, m := range c.ModelPricing[i].Models {
+ if strings.ToLower(m) == modelLower {
+ cp := c.ModelPricing[i].Clone()
+ return &cp
+ }
+ }
+ }
+ return nil
+}
+
+// pricingLookup 是渠道定价在单个计算过程中的索引:platform → (lowerName → *pricing)。
+// 用于将 SupportedModels 的定价解析从 O(N*M) 降到 O(N+M)。
+type pricingLookup map[string]map[string]*ChannelModelPricing
+
+// buildPricingLookup 对渠道的定价列表做一次扫描,生成 platform+模型名 的索引。
+// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。
+// wildcard 后缀(如 "claude-*")不会被索引(它们不是精确模型名)。
+func buildPricingLookup(pricings []ChannelModelPricing) pricingLookup {
+ lookup := make(pricingLookup, len(pricings))
+ for i := range pricings {
+ p := pricings[i]
+ byModel, ok := lookup[p.Platform]
+ if !ok {
+ byModel = make(map[string]*ChannelModelPricing, len(p.Models))
+ lookup[p.Platform] = byModel
+ }
+ for _, m := range p.Models {
+ if _, wild := splitWildcardSuffix(m); wild {
+ continue
+ }
+ lower := strings.ToLower(m)
+ if _, exists := byModel[lower]; exists {
+ continue // 首个命中胜出(保持 case-insensitive 去重后第一个定价)
+ }
+ cp := pricings[i].Clone()
+ byModel[lower] = &cp
+ }
+ }
+ return lookup
+}
+
+// pricedNamesFor 返回指定平台下已索引的精确模型名(保留原始大小写,按添加顺序)。
+// 实现上直接重扫 pricings(不经过 pricingLookup):跳过通配符条目,按小写去重,保留首次出现的原样字符串。
+func pricedNamesFor(pricings []ChannelModelPricing, platform string) []string {
+ seen := make(map[string]struct{})
+ out := make([]string, 0)
+ for i := range pricings {
+ if pricings[i].Platform != platform {
+ continue
+ }
+ for _, m := range pricings[i].Models {
+ if _, wild := splitWildcardSuffix(m); wild {
+ continue
+ }
+ lower := strings.ToLower(m)
+ if _, ok := seen[lower]; ok {
+ continue
+ }
+ seen[lower] = struct{}{}
+ out = append(out, m)
+ }
+ }
+ return out
+}
+
+// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。
+//
+// 算法(以渠道自身的 ModelMapping 为唯一入口):
+// - 遍历 Channel.ModelMapping 的每个 platform 条目;
+// - 映射 key 不带尾部 "*":直接作为一个支持模型名(即使没有匹配的定价行,也会产出 Pricing=nil 的条目);
+// - 映射 key 带尾部 "*":用同 platform 的 ModelPricing.Models 做前缀匹配展开(定价中带 "*" 的条目被忽略,因为它们本身就是模式,不是具体模型名);
+// - 未在 ModelMapping 中出现的 platform 不会产出任何条目——这是**刻意设计**("没配映射就不显示"),即使该平台有定价行。
+//
+// 每个结果尝试从 pricingLookup(平台+模型名索引)查找精确定价,未配置则 Pricing=nil。
+// 结果按 (Platform, Name) 稳定排序,并按 (Platform, lowercase(Name)) 去重。
+func (c *Channel) SupportedModels() []SupportedModel {
+ if c == nil || len(c.ModelMapping) == 0 {
+ return nil
+ }
+
+ lookup := buildPricingLookup(c.ModelPricing)
+
+ type dedupKey struct {
+ platform string
+ name string
+ }
+ seen := make(map[dedupKey]struct{})
+ result := make([]SupportedModel, 0)
+
+ add := func(platform, name string) {
+ key := dedupKey{platform: platform, name: strings.ToLower(name)}
+ if _, ok := seen[key]; ok {
+ return
+ }
+ seen[key] = struct{}{}
+ var pricing *ChannelModelPricing
+ if byModel, ok := lookup[platform]; ok {
+ if p, ok := byModel[strings.ToLower(name)]; ok {
+ pricing = p
+ }
+ }
+ result = append(result, SupportedModel{
+ Name: name,
+ Platform: platform,
+ Pricing: pricing,
+ })
+ }
+
+ for platform, mapping := range c.ModelMapping {
+ if len(mapping) == 0 {
+ continue
+ }
+ pricedNames := pricedNamesFor(c.ModelPricing, platform)
+ for src := range mapping {
+ prefix, isWild := splitWildcardSuffix(src)
+ if isWild {
+ prefixLower := strings.ToLower(prefix)
+ for _, candidate := range pricedNames {
+ if strings.HasPrefix(strings.ToLower(candidate), prefixLower) {
+ add(platform, candidate)
+ }
+ }
+ continue
+ }
+ add(platform, src)
+ }
+ }
+
+ sort.Slice(result, func(i, j int) bool {
+ if result[i].Platform != result[j].Platform {
+ return result[i].Platform < result[j].Platform
+ }
+ return result[i].Name < result[j].Name
+ })
+ return result
+}
diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go
new file mode 100644
index 00000000..700380c2
--- /dev/null
+++ b/backend/internal/service/channel_available.go
@@ -0,0 +1,84 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strings"
+)
+
+// AvailableGroupRef 渠道视图中关联分组的简要信息。
+type AvailableGroupRef struct {
+ ID int64
+ Name string
+ Platform string
+}
+
+// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 +
+// 关联的分组 + 推导出的支持模型列表(无通配符)。
+type AvailableChannel struct {
+ ID int64
+ Name string
+ Description string
+ Status string
+ BillingModelSource string
+ RestrictModels bool
+ Groups []AvailableGroupRef
+ SupportedModels []SupportedModel
+}
+
+// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。
+//
+// 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。
+// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
+// 的分组(已停用或删除)会被忽略。
+func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
+ channels, err := s.repo.ListAll(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list channels: %w", err)
+ }
+
+ groupByID := make(map[int64]AvailableGroupRef)
+ if s.groupRepo != nil {
+ groups, err := s.groupRepo.ListActive(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list active groups: %w", err)
+ }
+ for i := range groups {
+ g := groups[i]
+ groupByID[g.ID] = AvailableGroupRef{
+ ID: g.ID,
+ Name: g.Name,
+ Platform: g.Platform,
+ }
+ }
+ }
+
+ out := make([]AvailableChannel, 0, len(channels))
+ for i := range channels {
+ ch := &channels[i]
+ groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs))
+ for _, gid := range ch.GroupIDs {
+ if ref, ok := groupByID[gid]; ok {
+ groups = append(groups, ref)
+ }
+ }
+ sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
+
+ out = append(out, AvailableChannel{
+ ID: ch.ID,
+ Name: ch.Name,
+ Description: ch.Description,
+ Status: ch.Status,
+ BillingModelSource: ch.BillingModelSource,
+ RestrictModels: ch.RestrictModels,
+ Groups: groups,
+ SupportedModels: ch.SupportedModels(),
+ })
+ }
+
+ sort.SliceStable(out, func(i, j int) bool {
+ return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
+ })
+ return out, nil
+}
diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go
new file mode 100644
index 00000000..6a11fa4b
--- /dev/null
+++ b/backend/internal/service/channel_available_test.go
@@ -0,0 +1,119 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
+// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
+type stubGroupRepoForAvailable struct {
+ activeGroups []Group
+}
+
+func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
+ return s.activeGroups, nil
+}
+
+func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil }
+func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil }
+func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil }
+func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) {
+ return false, nil
+}
+func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
+ return 0, 0, nil
+}
+func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
+ return 0, nil
+}
+func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
+ return nil
+}
+func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
+ return nil
+}
+
+// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels,
+// groupRepo 由参数决定(可传 nil 测试 nil 分支)。
+func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService {
+ repo := &mockChannelRepository{
+ listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil },
+ }
+ return NewChannelService(repo, groupRepo, nil)
+}
+
+func TestListAvailable_NilGroupRepo_NoGroupsAttached(t *testing.T) {
+ // groupRepo 为 nil 时不应 panic,且每个渠道的 Groups 应为空切片。
+ channels := []Channel{{
+ ID: 1,
+ Name: "chA",
+ Status: StatusActive,
+ GroupIDs: []int64{10, 20},
+ }}
+ svc := newAvailableChannelService(channels, nil)
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 1)
+ require.Empty(t, out[0].Groups)
+}
+
+func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) {
+ // 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。
+ channels := []Channel{{
+ ID: 1,
+ Name: "chA",
+ Status: StatusActive,
+ GroupIDs: []int64{1, 99},
+ }}
+ groupRepo := &stubGroupRepoForAvailable{
+ activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}},
+ }
+ svc := newAvailableChannelService(channels, groupRepo)
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 1)
+ require.Len(t, out[0].Groups, 1)
+ require.Equal(t, int64(1), out[0].Groups[0].ID)
+}
+
+func TestListAvailable_SortedByName(t *testing.T) {
+ channels := []Channel{
+ {ID: 1, Name: "beta"},
+ {ID: 2, Name: "Alpha"},
+ {ID: 3, Name: "charlie"},
+ }
+ svc := newAvailableChannelService(channels, nil)
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 3)
+ require.Equal(t, "Alpha", out[0].Name)
+ require.Equal(t, "beta", out[1].Name)
+ require.Equal(t, "charlie", out[2].Name)
+}
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index c29550d9..250df07b 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -141,6 +141,7 @@ const (
// ChannelService 渠道管理服务
type ChannelService struct {
repo ChannelRepository
+ groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
cache atomic.Value // *channelCache
@@ -148,9 +149,10 @@ type ChannelService struct {
}
// NewChannelService 创建渠道服务实例
-func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
+func NewChannelService(repo ChannelRepository, groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
s := &ChannelService{
repo: repo,
+ groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
}
return s
@@ -884,12 +886,7 @@ func conflictsBetween(a, b modelEntry) bool {
// toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry {
- lower := strings.ToLower(pattern)
- isWild := strings.HasSuffix(lower, "*")
- prefix := lower
- if isWild {
- prefix = strings.TrimSuffix(lower, "*")
- }
+ prefix, isWild := splitWildcardSuffix(strings.ToLower(pattern))
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
}
diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go
index e1345618..e44b882b 100644
--- a/backend/internal/service/channel_service_test.go
+++ b/backend/internal/service/channel_service_test.go
@@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context
// ---------------------------------------------------------------------------
func newTestChannelService(repo *mockChannelRepository) *ChannelService {
- return NewChannelService(repo, nil)
+ return NewChannelService(repo, nil, nil)
}
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
- return NewChannelService(repo, auth)
+ return NewChannelService(repo, nil, auth)
}
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing
diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go
index deac64d6..812a3a63 100644
--- a/backend/internal/service/channel_test.go
+++ b/backend/internal/service/channel_test.go
@@ -433,3 +433,207 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last")
}
+
+func TestSupportedModels_ExactKeysAndPricing(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 10, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
+ {ID: 11, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "claude-sonnet-4-6": "claude-sonnet-4-6",
+ "claude-opus-4-6": "claude-opus-4-6",
+ },
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ require.Equal(t, "anthropic", got[0].Platform)
+ require.Equal(t, "claude-opus-4-6", got[0].Name)
+ require.NotNil(t, got[0].Pricing)
+ require.Equal(t, int64(11), got[0].Pricing.ID)
+ require.Equal(t, "claude-sonnet-4-6", got[1].Name)
+ require.Equal(t, int64(10), got[1].Pricing.ID)
+}
+
+func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
+ {ID: 2, Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "claude-sonnet-*": "claude-sonnet-4-6",
+ },
+ },
+ }
+
+ got := ch.SupportedModels()
+ names := make([]string, 0, len(got))
+ for _, m := range got {
+ names = append(names, m.Name)
+ }
+ require.ElementsMatch(t, []string{"claude-sonnet-4-5", "claude-sonnet-4-6"}, names)
+ for _, m := range got {
+ require.NotContains(t, m.Name, "*")
+ }
+}
+
+func TestSupportedModels_PlatformWithoutMappingSkipped(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
+ {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
+ // openai 没有 mapping 条目
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "anthropic", got[0].Platform)
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+}
+
+func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) {
+ ch := &Channel{
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+ require.Nil(t, got[0].Pricing)
+}
+
+func TestSupportedModels_DedupAndSort(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
+ {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "claude-sonnet-4-6": "upstream-a",
+ "claude-sonnet-*": "upstream-a",
+ },
+ "openai": {"gpt-4o": "gpt-4o"},
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 3)
+ require.Equal(t, "anthropic", got[0].Platform)
+ require.Equal(t, "claude-sonnet-4-5", got[0].Name)
+ require.Equal(t, "anthropic", got[1].Platform)
+ require.Equal(t, "claude-sonnet-4-6", got[1].Name)
+ require.Equal(t, "openai", got[2].Platform)
+ require.Equal(t, "gpt-4o", got[2].Name)
+}
+
+func TestSupportedModels_NilChannelAndEmpty(t *testing.T) {
+ var nilCh *Channel
+ require.Nil(t, nilCh.SupportedModels())
+
+ empty := &Channel{}
+ require.Nil(t, empty.SupportedModels())
+}
+
+func TestGetModelPricingByPlatform(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
+ {ID: 2, Platform: "openai", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(1e-6)},
+ },
+ }
+
+ ant := ch.GetModelPricingByPlatform("anthropic", "claude-sonnet-4-6")
+ require.NotNil(t, ant)
+ require.Equal(t, int64(1), ant.ID)
+
+ oa := ch.GetModelPricingByPlatform("openai", "claude-sonnet-4-6")
+ require.NotNil(t, oa)
+ require.Equal(t, int64(2), oa.ID)
+
+ require.Nil(t, ch.GetModelPricingByPlatform("gemini", "claude-sonnet-4-6"))
+}
+
+func TestSupportedModels_WildcardOnlyPricingRowsSkipped(t *testing.T) {
+ // 定价中含通配符条目(pattern),不应被当作具体模型名展开。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-*", "claude-sonnet-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-*": "claude-sonnet-4-6"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+ for _, m := range got {
+ require.NotContains(t, m.Name, "*")
+ }
+}
+
+func TestSupportedModels_WildcardPrefixMatchesNothing(t *testing.T) {
+ // 通配符模式无任何对应定价模型时,该平台应产出 0 个模型。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"gpt-foo-*": "gpt-foo-1"},
+ },
+ }
+ require.Empty(t, ch.SupportedModels())
+}
+
+func TestSupportedModels_CrossPlatformPricingDoesNotBleed(t *testing.T) {
+ // anthropic 的通配符不应拉入 openai 定价行,哪怕名字恰好前缀匹配。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"claude-sonnet-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-*": "x"},
+ },
+ }
+ require.Empty(t, ch.SupportedModels())
+}
+
+func TestSupportedModels_CaseInsensitiveDedup(t *testing.T) {
+ // 两行定价用不同大小写定义了同一模型,结果应去重为 1 条;首次出现的原始大小写保留。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"GPT-4o"}},
+ {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "openai": {"gpt-*": "x"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "GPT-4o", got[0].Name)
+}
+
+func TestSupportedModels_EmptyPlatformMapping(t *testing.T) {
+ // ModelMapping 有一个 platform key 但 value 是空 map —— 该 platform 应被跳过。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {},
+ },
+ }
+ require.Empty(t, ch.SupportedModels())
+}
diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go
index 905c4df6..7484eed5 100644
--- a/backend/internal/service/model_pricing_resolver_test.go
+++ b/backend/internal/service/model_pricing_resolver_test.go
@@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP
return map[int64]string{groupID: "anthropic"}, nil
},
}
- cs := NewChannelService(repo, nil)
+ cs := NewChannelService(repo, nil, nil)
bs := newTestBillingServiceForResolver()
return NewModelPricingResolver(cs, bs)
}
@@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
return nil, errors.New("database unavailable")
},
}
- cs := NewChannelService(repo, nil)
+ cs := NewChannelService(repo, nil, nil)
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(cs, bs)
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index f129ceaa..eb7e91d8 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -163,5 +163,42 @@ export async function getModelDefaultPricing(model: string): Promise {
+ const { data } = await apiClient.get('/admin/channels/available', {
+ signal: options?.signal
+ })
+ return data.items
+}
+
+const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, listAvailable }
export default channelsAPI
diff --git a/frontend/src/api/channels.ts b/frontend/src/api/channels.ts
new file mode 100644
index 00000000..98b890df
--- /dev/null
+++ b/frontend/src/api/channels.ts
@@ -0,0 +1,60 @@
+/**
+ * User Channels API endpoints (non-admin)
+ * 用户侧「可用渠道」聚合查询:渠道 + 用户可访问的分组 + 支持模型(含定价)。
+ */
+
+import { apiClient } from './client'
+import type { BillingMode } from '@/constants/channel'
+
+export interface UserAvailableGroup {
+ id: number
+ name: string
+ platform: string
+}
+
+export interface UserPricingInterval {
+ min_tokens: number
+ max_tokens: number | null
+ tier_label?: string
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ per_request_price: number | null
+}
+
+export interface UserSupportedModelPricing {
+ billing_mode: BillingMode
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ image_output_price: number | null
+ per_request_price: number | null
+ intervals: UserPricingInterval[]
+}
+
+export interface UserSupportedModel {
+ name: string
+ platform: string
+ pricing: UserSupportedModelPricing | null
+}
+
+export interface UserAvailableChannel {
+ name: string
+ description: string
+ groups: UserAvailableGroup[]
+ supported_models: UserSupportedModel[]
+}
+
+/** 列出当前用户可见的「可用渠道」(与 /groups/available 保持一致,返回扁平数组)。 */
+export async function getAvailable(options?: { signal?: AbortSignal }): Promise<UserAvailableChannel[]> {
+ const { data } = await apiClient.get('/channels/available', {
+ signal: options?.signal
+ })
+ return data
+}
+
+export const userChannelsAPI = { getAvailable }
+
+export default userChannelsAPI
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
index dd005a0d..6702468d 100644
--- a/frontend/src/api/index.ts
+++ b/frontend/src/api/index.ts
@@ -16,6 +16,7 @@ export { userAPI } from './user'
export { redeemAPI, type RedeemHistoryItem } from './redeem'
export { paymentAPI } from './payment'
export { userGroupsAPI } from './groups'
+export { userChannelsAPI } from './channels'
export { totpAPI } from './totp'
export { default as announcementsAPI } from './announcements'
export { channelMonitorUserAPI } from './channelMonitor'
diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue
new file mode 100644
index 00000000..403391a3
--- /dev/null
+++ b/frontend/src/components/channels/AvailableChannelsTable.vue
@@ -0,0 +1,110 @@
+
+
+
+ {{ row.name }}
+
+ {{ row.description }}
+
+
+
+
+
+ -
+
+
+
+ {{ g.name }}
+
+
+
+
+
+
+ {{ noModelsLabel }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/channels/PricingRow.vue b/frontend/src/components/channels/PricingRow.vue
new file mode 100644
index 00000000..8db077c0
--- /dev/null
+++ b/frontend/src/components/channels/PricingRow.vue
@@ -0,0 +1,29 @@
+
+
+ {{ label }}
+ {{ display }}
+
+
+
+
diff --git a/frontend/src/components/channels/SupportedModelChip.vue b/frontend/src/components/channels/SupportedModelChip.vue
new file mode 100644
index 00000000..82f27607
--- /dev/null
+++ b/frontend/src/components/channels/SupportedModelChip.vue
@@ -0,0 +1,214 @@
+
+
+
+
+ {{ model.platform }}
+
+ {{ model.name }}
+
+
+
+
+
+ {{ model.name }}
+
+ {{ model.platform }}
+
+
+
+
+ {{ noPricingLabel }}
+
+
+
+
+ {{ t(prefixKey('billingMode')) }}
+ {{ billingModeLabel }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t(prefixKey('intervals')) }}
+
+
+
+
+ {{ iv.tier_label }}
+ {{ formatRange(iv.min_tokens, iv.max_tokens) }}
+
+ {{ formatInterval(iv, model.pricing.billing_mode) }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 248e0021..25284276 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -648,6 +648,7 @@ function buildSelfNavItems(withDashboard: boolean): NavItem[] {
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
{ path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment },
+ { path: '/available-channels', label: t('nav.availableChannels'), icon: ChannelIcon, hideInSimpleMode: true },
{ path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
{ path: '/profile', label: t('nav.profile'), icon: UserIcon },
...customMenuItemsForUser.value.map((item): NavItem => ({
diff --git a/frontend/src/constants/channel.ts b/frontend/src/constants/channel.ts
new file mode 100644
index 00000000..c08f4800
--- /dev/null
+++ b/frontend/src/constants/channel.ts
@@ -0,0 +1,22 @@
+/** Channel status values (must match service.Status* constants in Go). */
+export const CHANNEL_STATUS_ACTIVE = 'active' as const
+export const CHANNEL_STATUS_DISABLED = 'disabled' as const
+export type ChannelStatus = typeof CHANNEL_STATUS_ACTIVE | typeof CHANNEL_STATUS_DISABLED
+
+/** Billing mode values (must match service.BillingMode* constants in Go). */
+export const BILLING_MODE_TOKEN = 'token' as const
+export const BILLING_MODE_PER_REQUEST = 'per_request' as const
+export const BILLING_MODE_IMAGE = 'image' as const
+export type BillingMode =
+ | typeof BILLING_MODE_TOKEN
+ | typeof BILLING_MODE_PER_REQUEST
+ | typeof BILLING_MODE_IMAGE
+
+/** Billing-model-source values (must match service.BillingModelSource* constants in Go). */
+export const BILLING_MODEL_SOURCE_REQUESTED = 'requested' as const
+export const BILLING_MODEL_SOURCE_UPSTREAM = 'upstream' as const
+export const BILLING_MODEL_SOURCE_CHANNEL_MAPPED = 'channel_mapped' as const
+export type BillingModelSource =
+ | typeof BILLING_MODEL_SOURCE_REQUESTED
+ | typeof BILLING_MODEL_SOURCE_UPSTREAM
+ | typeof BILLING_MODEL_SOURCE_CHANNEL_MAPPED
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index eb401ae2..a54639cc 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -344,6 +344,7 @@ export default {
users: 'Users',
groups: 'Groups',
channels: 'Channels',
+ availableChannels: 'Available Channels',
subscriptions: 'Subscriptions',
accounts: 'Accounts',
proxies: 'Proxies',
@@ -929,6 +930,38 @@ export default {
}
},
+ // Available Channels (user-facing)
+ availableChannels: {
+ title: 'Available Channels',
+ description: 'Channels you can access, along with their supported models and pricing',
+ searchPlaceholder: 'Search channels or models...',
+ empty: 'No available channels',
+ noModels: 'No models configured',
+ noPricing: 'Pricing not configured',
+ columns: {
+ name: 'Channel',
+ groups: 'Your Accessible Groups',
+ supportedModels: 'Supported Models'
+ },
+ pricing: {
+ billingMode: 'Billing Mode',
+ billingModeToken: 'Per Token',
+ billingModePerRequest: 'Per Request',
+ billingModeImage: 'Per Image',
+ inputPrice: 'Input',
+ outputPrice: 'Output',
+ cacheWritePrice: 'Cache Write',
+ cacheReadPrice: 'Cache Read',
+ imageOutputPrice: 'Image Output',
+ perRequestPrice: 'Per Request',
+ intervals: 'Tiered Pricing',
+ tierLabel: 'Tier',
+ tokenRange: 'Token Range',
+ unitPerMillion: '/ 1M tokens',
+ unitPerRequest: '/ request'
+ }
+ },
+
// Redeem
redeem: {
title: 'Redeem Code',
@@ -1980,6 +2013,48 @@ export default {
}
},
+ // Available Channels (aggregated read-only view)
+ availableChannels: {
+ title: 'Available Channels',
+ description: 'Aggregated view: each channel with its linked groups and supported models (wildcards expanded)',
+ searchPlaceholder: 'Search channels or models...',
+ columns: {
+ name: 'Channel',
+ status: 'Status',
+ billingSource: 'Billing Model Source',
+ groups: 'Linked Groups',
+ supportedModels: 'Supported Models'
+ },
+ empty: 'No data',
+ noGroups: 'No linked groups',
+ noModels: 'No model mapping configured',
+ noPricing: 'Pricing not configured',
+ statusActive: 'Active',
+ statusDisabled: 'Disabled',
+ billingSource: {
+ requested: 'Requested model',
+ upstream: 'Upstream model',
+ channel_mapped: 'Channel-mapped model'
+ },
+ pricing: {
+ billingMode: 'Billing Mode',
+ billingModeToken: 'Per Token',
+ billingModePerRequest: 'Per Request',
+ billingModeImage: 'Per Image',
+ inputPrice: 'Input',
+ outputPrice: 'Output',
+ cacheWritePrice: 'Cache Write',
+ cacheReadPrice: 'Cache Read',
+ imageOutputPrice: 'Image Output',
+ perRequestPrice: 'Per Request',
+ intervals: 'Tiered Pricing',
+ tierLabel: 'Tier',
+ tokenRange: 'Token Range',
+ unitPerMillion: '/ 1M tokens',
+ unitPerRequest: '/ request'
+ }
+ },
+
// Channel Management
channels: {
title: 'Channel Management',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index d38b5034..e69b0223 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -344,6 +344,7 @@ export default {
users: '用户管理',
groups: '分组管理',
channels: '渠道管理',
+ availableChannels: '可用渠道',
subscriptions: '订阅管理',
accounts: '账号管理',
proxies: 'IP管理',
@@ -933,6 +934,38 @@ export default {
}
},
+ // Available Channels (user-facing)
+ availableChannels: {
+ title: '可用渠道',
+ description: '查看您可访问的渠道与其支持的模型、定价',
+ searchPlaceholder: '搜索渠道或模型...',
+ empty: '暂无可用渠道',
+ noModels: '未配置模型',
+ noPricing: '未配置定价',
+ columns: {
+ name: '渠道名',
+ groups: '我可访问的分组',
+ supportedModels: '支持模型'
+ },
+ pricing: {
+ billingMode: '计费模式',
+ billingModeToken: '按 Token',
+ billingModePerRequest: '按次',
+ billingModeImage: '按图片',
+ inputPrice: '输入',
+ outputPrice: '输出',
+ cacheWritePrice: '缓存写入',
+ cacheReadPrice: '缓存读取',
+ imageOutputPrice: '图片输出',
+ perRequestPrice: '每次请求',
+ intervals: '阶梯定价',
+ tierLabel: '层级',
+ tokenRange: 'Token 区间',
+ unitPerMillion: '/ 1M token',
+ unitPerRequest: '/ 次'
+ }
+ },
+
// Redeem
redeem: {
title: '兑换码',
@@ -2059,6 +2092,48 @@ export default {
}
},
+ // Available Channels (aggregated read-only view)
+ availableChannels: {
+ title: '可用渠道',
+ description: '按渠道聚合查看关联分组与支持模型(已展开通配符)',
+ searchPlaceholder: '搜索渠道或模型...',
+ columns: {
+ name: '渠道名',
+ status: '状态',
+ billingSource: '计费模型来源',
+ groups: '关联分组',
+ supportedModels: '支持模型'
+ },
+ empty: '暂无数据',
+ noGroups: '未关联分组',
+ noModels: '未配置模型映射',
+ noPricing: '未配置定价',
+ statusActive: '启用',
+ statusDisabled: '停用',
+ billingSource: {
+ requested: '请求模型',
+ upstream: '上游模型',
+ channel_mapped: '映射后模型'
+ },
+ pricing: {
+ billingMode: '计费模式',
+ billingModeToken: '按 Token',
+ billingModePerRequest: '按次',
+ billingModeImage: '按图片',
+ inputPrice: '输入',
+ outputPrice: '输出',
+ cacheWritePrice: '缓存写入',
+ cacheReadPrice: '缓存读取',
+ imageOutputPrice: '图片输出',
+ perRequestPrice: '每次请求',
+ intervals: '阶梯定价',
+ tierLabel: '层级',
+ tokenRange: 'Token 区间',
+ unitPerMillion: '/ 1M token',
+ unitPerRequest: '/ 次'
+ }
+ },
+
// Channel Management
channels: {
title: '渠道管理',
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index 491a984d..567876b6 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -197,6 +197,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'redeem.description'
}
},
+ {
+ path: '/available-channels',
+ name: 'UserAvailableChannels',
+ component: () => import('@/views/user/AvailableChannelsView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: false,
+ title: 'Available Channels',
+ titleKey: 'availableChannels.title',
+ descriptionKey: 'availableChannels.description'
+ }
+ },
{
path: '/profile',
name: 'Profile',
@@ -358,6 +370,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.groups.description'
}
},
+ {
+ path: '/admin/available-channels',
+ name: 'AdminAvailableChannels',
+ component: () => import('@/views/admin/AvailableChannelsView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: true,
+ title: 'Available Channels',
+ titleKey: 'admin.availableChannels.title',
+ descriptionKey: 'admin.availableChannels.description'
+ }
+ },
{
path: '/admin/channels',
redirect: '/admin/channels/pricing'
diff --git a/frontend/src/views/admin/AvailableChannelsView.vue b/frontend/src/views/admin/AvailableChannelsView.vue
new file mode 100644
index 00000000..3f0ee436
--- /dev/null
+++ b/frontend/src/views/admin/AvailableChannelsView.vue
@@ -0,0 +1,135 @@
+
+
+
+
+
+
+
+
+
+ {{ t('admin.availableChannels.noGroups') }}
+
+
+
+ {{ statusLabel(row.status) }}
+
+
+
+
+
+ {{
+ t(
+ `admin.availableChannels.billingSource.${row.billing_model_source || BILLING_MODEL_SOURCE_CHANNEL_MAPPED}`
+ )
+ }}
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/user/AvailableChannelsView.vue b/frontend/src/views/user/AvailableChannelsView.vue
new file mode 100644
index 00000000..44ee456e
--- /dev/null
+++ b/frontend/src/views/user/AvailableChannelsView.vue
@@ -0,0 +1,98 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
--
GitLab
From 1521d503990f2b3ab6d474958b64c1e4f5fb3baf Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:31:52 +0800
Subject: [PATCH 099/261] fix: apply email first-bind defaults on legacy login
---
backend/internal/service/auth_service.go | 80 ++++++--
.../auth_service_identity_sync_test.go | 189 +++++++++++++++++-
2 files changed, 238 insertions(+), 31 deletions(-)
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index dda6df04..d0d5e4e3 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -807,37 +807,75 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
if s == nil || user == nil || user.ID <= 0 {
return
}
- s.ensureEmailAuthIdentity(ctx, user)
+ if s.ensureEmailAuthIdentity(ctx, user) {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
+ }
+ }
}
-func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
- return
+ return false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
- return
+ return false
}
- if err := s.entClient.AuthIdentity.Create().
- SetUserID(user.ID).
- SetProviderType("email").
- SetProviderKey("email").
- SetProviderSubject(email).
- SetVerifiedAt(time.Now().UTC()).
- SetMetadata(map[string]any{
- "source": "auth_service_dual_write",
- }).
- OnConflictColumns(
- authidentity.FieldProviderType,
- authidentity.FieldProviderKey,
- authidentity.FieldProviderSubject,
- ).
- DoNothing().
- Exec(ctx); err != nil {
- logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
+ }
+
+ buildQuery := func() *dbent.AuthIdentityQuery {
+ return client.AuthIdentity.Query().Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(email),
+ )
}
+
+ existed, err := buildQuery().Exist(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return false
+ }
+
+ if !existed {
+ if err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(email).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{
+ "source": "auth_service_dual_write",
+ }).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return false
+ }
+ }
+
+ identity, err := buildQuery().Only(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return false
+ }
+ if identity.UserID != user.ID {
+ logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
+ return false
+ }
+
+ return !existed
}
func inferLegacySignupSource(email string) string {
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index fcb4813b..e2a94b13 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -21,6 +21,19 @@ import (
_ "modernc.org/sqlite"
)
+type authIdentityDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
type authIdentitySettingRepoStub struct {
values map[string]string
}
@@ -40,8 +53,14 @@ func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error
panic("unexpected Set call")
}
-func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
}
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
@@ -56,7 +75,11 @@ func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
-func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
+func newAuthServiceWithEnt(
+ t *testing.T,
+ settings map[string]string,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
@@ -65,6 +88,16 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@@ -82,17 +115,17 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
},
}
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
- values: map[string]string{
- service.SettingKeyRegistrationEnabled: "true",
- },
+ values: settings,
}, cfg)
- svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
- svc, _, client := newAuthServiceWithEnt(t)
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
ctx := context.Background()
token, user, err := svc.Register(ctx, "user@example.com", "password")
@@ -119,7 +152,9 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
}
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
- svc, repo, client := newAuthServiceWithEnt(t)
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
ctx := context.Background()
user := &service.User{
@@ -163,7 +198,9 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
}
func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
- svc, repo, client := newAuthServiceWithEnt(t)
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
ctx := context.Background()
user := &service.User{
@@ -188,3 +225,135 @@ func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
}
+
+func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNew(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("legacy@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("bound@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(2).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("bound@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "preexisting"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2.0, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var count int
+ rows, err := client.QueryContext(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ userID,
+ providerType,
+ grantReason,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+ require.True(t, rows.Next())
+ require.NoError(t, rows.Scan(&count))
+ require.NoError(t, rows.Err())
+ return count
+}
--
GitLab
From 55e8dd550a4a261ab4bdf26f865da2eab11dafd8 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:33:23 +0800
Subject: [PATCH 100/261] Tighten WeChat payment resume flow
---
backend/internal/handler/auth_wechat_oauth.go | 36 +++++---
.../handler/auth_wechat_oauth_test.go | 61 +++++++++++++
backend/internal/handler/payment_handler.go | 69 ++++++++++++--
.../handler/payment_handler_resume_test.go | 61 +++++++++++++
.../internal/service/payment_resume_lookup.go | 4 +
.../service/payment_resume_service.go | 91 ++++++++++++++++---
.../service/payment_resume_service_test.go | 33 +++++++
frontend/src/types/payment.ts | 1 +
.../views/auth/WechatPaymentCallbackView.vue | 24 ++---
.../WechatPaymentCallbackView.spec.ts | 13 +--
frontend/src/views/user/PaymentResultView.vue | 3 +-
frontend/src/views/user/PaymentView.vue | 66 ++++++--------
.../user/__tests__/PaymentResultView.spec.ts | 19 ++++
.../__tests__/paymentWechatResume.spec.ts | 56 ++++++++++++
.../src/views/user/paymentWechatResume.ts | 77 ++++++++++++++++
15 files changed, 515 insertions(+), 99 deletions(-)
create mode 100644 backend/internal/handler/payment_handler_resume_test.go
create mode 100644 frontend/src/views/user/__tests__/paymentWechatResume.spec.ts
create mode 100644 frontend/src/views/user/paymentWechatResume.ts
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index b078b804..45de30a8 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -435,24 +435,34 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
scope = strings.TrimSpace(tokenResp.Scope)
}
- fragment := url.Values{}
- fragment.Set("openid", openid)
- fragment.Set("state", state)
- fragment.Set("scope", scope)
- fragment.Set("payment_type", paymentContext.PaymentType)
- if paymentContext.Amount != "" {
- fragment.Set("amount", paymentContext.Amount)
- }
- if paymentContext.OrderType != "" {
- fragment.Set("order_type", paymentContext.OrderType)
- }
- if paymentContext.PlanID > 0 {
- fragment.Set("plan_id", strconv.FormatInt(paymentContext.PlanID, 10))
+ resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{
+ OpenID: openid,
+ PaymentType: paymentContext.PaymentType,
+ Amount: paymentContext.Amount,
+ OrderType: paymentContext.OrderType,
+ PlanID: paymentContext.PlanID,
+ RedirectTo: redirectTo,
+ Scope: scope,
+ })
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "")
+ return
}
+
+ fragment := url.Values{}
+ fragment.Set("wechat_resume_token", resumeToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
+func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
+ key, err := payment.ProvideEncryptionKey(h.cfg)
+ if err != nil {
+ return service.NewPaymentResumeService(nil)
+ }
+ return service.NewPaymentResumeService([]byte(key))
+}
+
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
index def9d5d6..c65f4cd1 100644
--- a/backend/internal/handler/auth_wechat_oauth_test.go
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -175,6 +176,66 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
require.Zero(t, count)
}
+func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
+ t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
+ t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
+
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+ handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+ require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
+ require.NotEmpty(t, fragment.Get("wechat_resume_token"))
+ require.Empty(t, fragment.Get("openid"))
+ require.Empty(t, fragment.Get("payment_type"))
+ require.Empty(t, fragment.Get("amount"))
+ require.Empty(t, fragment.Get("order_type"))
+ require.Empty(t, fragment.Get("plan_id"))
+
+ claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
+ require.NoError(t, err)
+ require.Equal(t, "openid-123", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "12.5", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 7, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+}
+
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
testCases := []struct {
name string
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index d54cbe92..5fd6b43e 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -1,9 +1,12 @@
package handler
import (
+ "fmt"
"strconv"
"strings"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -202,14 +205,15 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) {
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
- Amount float64 `json:"amount"`
- PaymentType string `json:"payment_type" binding:"required"`
- OpenID string `json:"openid"`
- ReturnURL string `json:"return_url"`
- PaymentSource string `json:"payment_source"`
- OrderType string `json:"order_type"`
- PlanID int64 `json:"plan_id"`
- IsMobile *bool `json:"is_mobile,omitempty"`
+ Amount float64 `json:"amount"`
+ PaymentType string `json:"payment_type" binding:"required"`
+ OpenID string `json:"openid"`
+ WechatResumeToken string `json:"wechat_resume_token"`
+ ReturnURL string `json:"return_url"`
+ PaymentSource string `json:"payment_source"`
+ OrderType string `json:"order_type"`
+ PlanID int64 `json:"plan_id"`
+ IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -225,6 +229,17 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if strings.TrimSpace(req.WechatResumeToken) != "" {
+ claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
mobile := isMobile(c)
if req.IsMobile != nil {
@@ -253,6 +268,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.Success(c, result)
}
+func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
+ if req == nil || claims == nil {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
+ }
+ openid := strings.TrimSpace(claims.OpenID)
+ if openid == "" {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+
+ paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
+ if paymentType == "" {
+ paymentType = payment.TypeWxpay
+ }
+ if req.PaymentType != "" {
+ requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
+ if requestPaymentType != "" && requestPaymentType != paymentType {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
+ }
+ }
+ req.PaymentType = paymentType
+ req.OpenID = openid
+
+ if strings.TrimSpace(claims.Amount) != "" {
+ amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
+ if err != nil || amount <= 0 {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
+ }
+ req.Amount = amount
+ }
+ if claims.OrderType != "" {
+ req.OrderType = claims.OrderType
+ }
+ if claims.PlanID > 0 {
+ req.PlanID = claims.PlanID
+ }
+ return nil
+}
+
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
new file mode 100644
index 00000000..323f7292
--- /dev/null
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -0,0 +1,61 @@
+//go:build unit
+
+package handler
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ Amount: 0,
+ PaymentType: payment.TypeWxpay,
+ OrderType: payment.OrderTypeBalance,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ })
+ if err != nil {
+ t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
+ }
+ if req.OpenID != "openid-123" {
+ t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
+ }
+ if req.Amount != 12.5 {
+ t.Fatalf("amount = %v, want 12.5", req.Amount)
+ }
+ if req.OrderType != payment.OrderTypeSubscription {
+ t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
+ }
+ if req.PlanID != 7 {
+ t.Fatalf("plan_id = %d, want 7", req.PlanID)
+ }
+}
+
+func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ PaymentType: payment.TypeAlipay,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeBalance,
+ })
+ if err == nil {
+ t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
+ }
+}
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
index 493ca325..69033afd 100644
--- a/backend/internal/service/payment_resume_lookup.go
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -33,3 +33,7 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return order, nil
}
+
+func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
+}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
index 4f63645e..64d1d125 100644
--- a/backend/internal/service/payment_resume_service.go
+++ b/backend/internal/service/payment_resume_service.go
@@ -31,6 +31,8 @@ const (
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
+
+ wechatPaymentResumeTokenType = "wechat_payment_resume"
)
type ResumeTokenClaims struct {
@@ -43,6 +45,18 @@ type ResumeTokenClaims struct {
IssuedAt int64 `json:"iat"`
}
+type WeChatPaymentResumeClaims struct {
+ TokenType string `json:"tk,omitempty"`
+ OpenID string `json:"openid"`
+ PaymentType string `json:"pt,omitempty"`
+ Amount string `json:"amt,omitempty"`
+ OrderType string `json:"ot,omitempty"`
+ PlanID int64 `json:"pid,omitempty"`
+ RedirectTo string `json:"rd,omitempty"`
+ Scope string `json:"scp,omitempty"`
+ IssuedAt int64 `json:"iat"`
+}
+
type PaymentResumeService struct {
signingKey []byte
}
@@ -232,6 +246,66 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ var claims ResumeTokenClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ }
+ if claims.OrderID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return "", fmt.Errorf("wechat payment resume token requires openid")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ claims.TokenType = wechatPaymentResumeTokenType
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ var claims WeChatPaymentResumeClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
+ }
+ if claims.TokenType != wechatPaymentResumeTokenType {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch")
+ }
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) createSignedToken(claims any) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
@@ -240,26 +314,19 @@ func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, er
return encodedPayload + "." + s.sign(encodedPayload), nil
}
-func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
- }
- var claims ResumeTokenClaims
- if err := json.Unmarshal(payload, &claims); err != nil {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
- if claims.OrderID <= 0 {
- return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
- }
- return &claims, nil
+ return json.Unmarshal(payload, dest)
}
func (s *PaymentResumeService) sign(payload string) string {
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
index 9c35ac3d..24d50494 100644
--- a/backend/internal/service/payment_resume_service_test.go
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -150,6 +150,39 @@ func TestPaymentResumeTokenRoundTrip(t *testing.T) {
}
}
+func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ RedirectTo: "/purchase?from=wechat",
+ Scope: "snsapi_base",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
+ t.Fatalf("claims payment context mismatch: %+v", claims)
+ }
+ if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
+ t.Fatalf("claims redirect/scope mismatch: %+v", claims)
+ }
+}
+
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts
index 5cd49064..77fbe689 100644
--- a/frontend/src/types/payment.ts
+++ b/frontend/src/types/payment.ts
@@ -157,6 +157,7 @@ export interface CreateOrderRequest {
return_url?: string
payment_source?: string
openid?: string
+ wechat_resume_token?: string
is_mobile?: boolean
}
diff --git a/frontend/src/views/auth/WechatPaymentCallbackView.vue b/frontend/src/views/auth/WechatPaymentCallbackView.vue
index 422a0bb8..73095102 100644
--- a/frontend/src/views/auth/WechatPaymentCallbackView.vue
+++ b/frontend/src/views/auth/WechatPaymentCallbackView.vue
@@ -114,23 +114,17 @@ onMounted(async () => {
return
}
- const openid = readParam('openid')
- const state = readParam('state')
- const scope = readParam('scope')
- const paymentType = readParam('payment_type')
- const amount = readParam('amount')
- const orderType = readParam('order_type')
- const planId = readParam('plan_id')
+ const resumeToken = readParam('wechat_resume_token')
const redirectURL = new URL(
normalizeRedirectPath(readParam('redirect')),
window.location.origin,
)
- if (!openid) {
+ if (!resumeToken) {
errorMessage.value = textWithFallback(
- 'auth.wechatPayment.callbackMissingOpenId',
- '微信支付回调缺少 openid。',
- 'The WeChat payment callback is missing the openid.',
+ 'auth.wechatPayment.callbackMissingResumeToken',
+ '微信支付回调缺少恢复令牌。',
+ 'The WeChat payment callback is missing the resume token.',
)
return
}
@@ -138,14 +132,8 @@ onMounted(async () => {
+      const query: Record<string, string> = {
...Object.fromEntries(redirectURL.searchParams.entries()),
wechat_resume: '1',
- openid,
+ wechat_resume_token: resumeToken,
}
- if (state) query.state = state
- if (scope) query.scope = scope
- if (paymentType) query.payment_type = paymentType
- if (amount) query.amount = amount
- if (orderType) query.order_type = orderType
- if (planId) query.plan_id = planId
await router.replace({
path: redirectURL.pathname,
diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
index cfbd9f1c..400e50d5 100644
--- a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts
@@ -49,8 +49,8 @@ describe('WechatPaymentCallbackView', () => {
})
})
- it('redirects back to purchase with openid and payment context from hash fragment', async () => {
- locationState.current.hash = '#openid=openid-123&payment_type=wxpay&amount=12.5&order_type=balance&redirect=%2Fpurchase%3Ffrom%3Dwechat'
+ it('redirects back to purchase with an opaque resume token from hash fragment', async () => {
+ locationState.current.hash = '#wechat_resume_token=resume-token-123&redirect=%2Fpurchase%3Ffrom%3Dwechat'
mount(WechatPaymentCallbackView)
await flushPromises()
@@ -60,21 +60,18 @@ describe('WechatPaymentCallbackView', () => {
query: {
from: 'wechat',
wechat_resume: '1',
- openid: 'openid-123',
- payment_type: 'wxpay',
- amount: '12.5',
- order_type: 'balance',
+ wechat_resume_token: 'resume-token-123',
},
})
})
- it('shows an error when the callback payload is missing openid', async () => {
+ it('shows an error when the callback payload is missing the resume token', async () => {
locationState.current.hash = '#payment_type=wxpay'
const wrapper = mount(WechatPaymentCallbackView)
await flushPromises()
expect(replaceMock).not.toHaveBeenCalled()
- expect(wrapper.text()).toContain('微信支付回调缺少 openid。')
+ expect(wrapper.text()).toContain('微信支付回调缺少恢复令牌。')
})
})
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index e1bcbbe5..5c843d1f 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -188,7 +188,8 @@ onMounted(async () => {
}
}
- const hasLegacyFallbackContext = Boolean(route.query.trade_status || route.query.money || route.query.type)
+ const hasLegacyFallbackContext = typeof route.query.trade_status === 'string'
+ && route.query.trade_status.trim() !== ''
if (!order.value && !resumeToken && !orderId && outTradeNo && hasLegacyFallbackContext) {
returnInfo.value = {
outTradeNo,
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index 019de16a..a3a81ccb 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -284,6 +284,7 @@ import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue'
import Icon from '@/components/icons/Icon.vue'
import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue'
import { describePaymentScenarioError } from './paymentUx'
+import { parseWechatResumeRoute, stripWechatResumeQuery } from './paymentWechatResume'
const { t } = useI18n()
const route = useRoute()
@@ -315,6 +316,7 @@ const paymentPhase = ref<'select' | 'paying'>('select')
interface CreateOrderOptions {
openid?: string
+ wechatResumeToken?: string
paymentType?: string
isResume?: boolean
}
@@ -344,13 +346,6 @@ function emptyPaymentState(): PaymentRecoverySnapshot {
}
}
-function readRouteQueryValue(value: unknown): string {
- if (Array.isArray(value)) {
- return typeof value[0] === 'string' ? value[0] : ''
- }
- return typeof value === 'string' ? value : ''
-}
-
function getWeixinJSBridge(): WeixinJSBridgeLike | undefined {
return (window as Window & { WeixinJSBridge?: WeixinJSBridgeLike }).WeixinJSBridge
}
@@ -637,6 +632,9 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
if (options.openid) {
payload.openid = options.openid
}
+ if (options.wechatResumeToken) {
+ payload.wechat_resume_token = options.wechatResumeToken
+ }
payload.is_mobile = isMobileDevice()
const result = await paymentStore.createOrder(payload) as CreateOrderResult & { resume_token?: string }
@@ -744,44 +742,34 @@ function applyScenarioError(err: unknown, paymentMethod: string) {
}
async function resumeWechatPaymentFromQuery() {
- const openid = readRouteQueryValue(route.query.openid)
- if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) {
+ const resume = parseWechatResumeRoute(route.query, checkout.value.plans, validAmount.value)
+ if (!resume) {
return
}
- const paymentType = normalizeVisibleMethod(readRouteQueryValue(route.query.payment_type)) || 'wxpay'
- const orderType = readRouteQueryValue(route.query.order_type) === 'subscription' ? 'subscription' : 'balance'
- const planId = Number.parseInt(readRouteQueryValue(route.query.plan_id), 10)
- const rawAmount = Number.parseFloat(readRouteQueryValue(route.query.amount))
- const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0
- ? rawAmount
- : (orderType === 'subscription'
- ? (checkout.value.plans.find(plan => plan.id === planId)?.price ?? 0)
- : validAmount.value)
-
- selectedMethod.value = paymentType
- if (orderType === 'balance' && orderAmount > 0) {
- amount.value = orderAmount
+ selectedMethod.value = resume.paymentType
+ if (resume.orderType === 'balance' && resume.orderAmount > 0) {
+ amount.value = resume.orderAmount
+ }
+ if (resume.orderType === 'subscription' && resume.planId) {
+ selectedPlan.value = checkout.value.plans.find(plan => plan.id === resume.planId) ?? null
}
- if (orderType === 'subscription' && Number.isFinite(planId) && planId > 0) {
- selectedPlan.value = checkout.value.plans.find(plan => plan.id === planId) ?? null
+
+ await router.replace({ path: route.path, query: stripWechatResumeQuery(route.query) })
+
+ if (resume.wechatResumeToken) {
+ await createOrder(0, resume.orderType, resume.planId, {
+ wechatResumeToken: resume.wechatResumeToken,
+ paymentType: resume.paymentType,
+ isResume: true,
+ })
+ return
}
- const nextQuery = { ...route.query }
- delete nextQuery.wechat_resume
- delete nextQuery.openid
- delete nextQuery.state
- delete nextQuery.scope
- delete nextQuery.payment_type
- delete nextQuery.amount
- delete nextQuery.order_type
- delete nextQuery.plan_id
- await router.replace({ path: route.path, query: nextQuery })
-
- if (orderAmount > 0) {
- await createOrder(orderAmount, orderType, Number.isFinite(planId) && planId > 0 ? planId : undefined, {
- openid,
- paymentType,
+ if (resume.orderAmount > 0 && resume.openid) {
+ await createOrder(resume.orderAmount, resume.orderType, resume.planId, {
+ openid: resume.openid,
+ paymentType: resume.paymentType,
isResume: true,
})
}
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index bfc044a7..d8199e3b 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -157,6 +157,25 @@ describe('PaymentResultView', () => {
expect(wrapper.text()).toContain('payment.result.success')
})
+ it('does not use public out_trade_no verification for bare order numbers without legacy return markers', async () => {
+ routeState.query = {
+ out_trade_no: 'legacy-bare',
+ }
+
+ mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ expect(verifyOrder).not.toHaveBeenCalled()
+ })
+
it('resolves order by resume token when local recovery snapshot is missing', async () => {
routeState.query = {
resume_token: 'resume-77',
diff --git a/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts b/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts
new file mode 100644
index 00000000..c850ec1b
--- /dev/null
+++ b/frontend/src/views/user/__tests__/paymentWechatResume.spec.ts
@@ -0,0 +1,56 @@
+import { describe, expect, it } from 'vitest'
+import { parseWechatResumeRoute, stripWechatResumeQuery } from '../paymentWechatResume'
+
+describe('parseWechatResumeRoute', () => {
+ it('prefers the opaque resume token over legacy openid query params', () => {
+ expect(parseWechatResumeRoute({
+ wechat_resume: '1',
+ wechat_resume_token: 'resume-token-123',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'subscription',
+ plan_id: '7',
+ }, [], 88)).toEqual({
+ wechatResumeToken: 'resume-token-123',
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ orderAmount: 0,
+ })
+ })
+
+ it('falls back to legacy openid-based resume when opaque token is absent', () => {
+ expect(parseWechatResumeRoute({
+ wechat_resume: '1',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'balance',
+ }, [], 88)).toEqual({
+ openid: 'openid-123',
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ orderAmount: 12.5,
+ planId: undefined,
+ })
+ })
+})
+
+describe('stripWechatResumeQuery', () => {
+ it('removes both opaque-token and legacy resume params from the route query', () => {
+ expect(stripWechatResumeQuery({
+ foo: 'bar',
+ wechat_resume: '1',
+ wechat_resume_token: 'resume-token-123',
+ openid: 'openid-123',
+ payment_type: 'wxpay',
+ amount: '12.5',
+ order_type: 'subscription',
+ plan_id: '7',
+ state: 'state-123',
+ scope: 'snsapi_base',
+ })).toEqual({
+ foo: 'bar',
+ })
+ })
+})
diff --git a/frontend/src/views/user/paymentWechatResume.ts b/frontend/src/views/user/paymentWechatResume.ts
new file mode 100644
index 00000000..f53c8457
--- /dev/null
+++ b/frontend/src/views/user/paymentWechatResume.ts
@@ -0,0 +1,77 @@
+import type { LocationQuery, LocationQueryRaw } from 'vue-router'
+import type { SubscriptionPlan } from '@/types/payment'
+import { normalizeVisibleMethod } from '@/components/payment/paymentFlow'
+
+export interface ParsedWechatResumeRoute {
+ orderAmount: number
+ orderType: 'balance' | 'subscription'
+ paymentType: string
+ planId?: number
+ openid?: string
+ wechatResumeToken?: string
+}
+
+function readQueryString(query: LocationQuery, key: string): string {
+ const value = query[key]
+ if (Array.isArray(value)) {
+ return typeof value[0] === 'string' ? value[0] : ''
+ }
+ return typeof value === 'string' ? value : ''
+}
+
+export function parseWechatResumeRoute(
+ query: LocationQuery,
+ plans: SubscriptionPlan[],
+ fallbackBalanceAmount: number,
+): ParsedWechatResumeRoute | null {
+ if (readQueryString(query, 'wechat_resume') !== '1') {
+ return null
+ }
+
+ const wechatResumeToken = readQueryString(query, 'wechat_resume_token')
+ if (wechatResumeToken) {
+ return {
+ wechatResumeToken,
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ orderAmount: 0,
+ }
+ }
+
+ const openid = readQueryString(query, 'openid')
+ if (!openid) {
+ return null
+ }
+
+ const paymentType = normalizeVisibleMethod(readQueryString(query, 'payment_type')) || 'wxpay'
+ const orderType = readQueryString(query, 'order_type') === 'subscription' ? 'subscription' : 'balance'
+ const planId = Number.parseInt(readQueryString(query, 'plan_id'), 10)
+ const rawAmount = Number.parseFloat(readQueryString(query, 'amount'))
+ const orderAmount = Number.isFinite(rawAmount) && rawAmount > 0
+ ? rawAmount
+ : (orderType === 'subscription'
+ ? (plans.find(plan => plan.id === planId)?.price ?? 0)
+ : fallbackBalanceAmount)
+
+ return {
+ openid,
+ paymentType,
+ orderType,
+ orderAmount,
+ planId: Number.isFinite(planId) && planId > 0 ? planId : undefined,
+ }
+}
+
+export function stripWechatResumeQuery(query: LocationQuery): LocationQueryRaw {
+ const nextQuery: LocationQueryRaw = { ...query }
+ delete nextQuery.wechat_resume
+ delete nextQuery.wechat_resume_token
+ delete nextQuery.openid
+ delete nextQuery.state
+ delete nextQuery.scope
+ delete nextQuery.payment_type
+ delete nextQuery.amount
+ delete nextQuery.order_type
+ delete nextQuery.plan_id
+ return nextQuery
+}
--
GitLab
From 030da8c2f66bdf031c245867322660b44776ed8d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:41:29 +0800
Subject: [PATCH 101/261] fix: close admin settings review gaps
---
.../settings.paymentVisibleMethods.spec.ts | 8 ++---
frontend/src/api/admin/settings.ts | 4 +--
frontend/src/components/layout/AppSidebar.vue | 6 ++++
.../layout/__tests__/AppSidebar.spec.ts | 6 ++++
.../AuthIdentityMigrationReportsView.vue | 30 +++++++++++++++----
frontend/src/views/admin/SettingsView.vue | 30 +++++++++++++++++--
.../AuthIdentityMigrationReportsView.spec.ts | 17 +++++++++++
.../admin/__tests__/SettingsView.spec.ts | 23 ++++++++++++++
8 files changed, 110 insertions(+), 14 deletions(-)
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
index 3b1a373f..ad355afe 100644
--- a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -27,8 +27,8 @@ describe('admin settings payment visible method helpers', () => {
expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
{
value: '',
- labelZh: '自动路由',
- labelEn: 'Automatic routing',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
},
{
value: 'official_alipay',
@@ -45,8 +45,8 @@ describe('admin settings payment visible method helpers', () => {
expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
{
value: '',
- labelZh: '自动路由',
- labelEn: 'Automatic routing',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
},
{
value: 'official_wxpay',
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 505fcdca..235bda7b 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -44,12 +44,12 @@ const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
PaymentVisibleMethodSourceOption[]
> = {
alipay: [
- { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: '', labelZh: '未配置', labelEn: 'Not configured' },
{ value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' },
{ value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' },
],
wxpay: [
- { value: '', labelZh: '自动路由', labelEn: 'Automatic routing' },
+ { value: '', labelZh: '未配置', labelEn: 'Not configured' },
{ value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' },
{ value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' },
],
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 92dcc519..b7158d9b 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -663,6 +663,12 @@ const adminNavItems = computed((): NavItem[] => {
? [{ path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon }]
: []),
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
+ {
+ path: '/admin/users/auth-identity-migration-reports',
+ label: 'Migration Reports',
+ icon: UsersIcon,
+ hideInSimpleMode: true
+ },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
{ path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
diff --git a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
index 118c7615..915a67f8 100644
--- a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
+++ b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
@@ -30,3 +30,9 @@ describe('AppSidebar header styles', () => {
expect(sidebarBrandBlockMatch?.[0]).not.toContain('overflow: hidden;')
})
})
+
+describe('AppSidebar admin navigation', () => {
+ it('includes a visible entry for auth identity migration reports', () => {
+ expect(componentSource).toContain("'/admin/users/auth-identity-migration-reports'")
+ })
+})
diff --git a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
index 5aeb6b28..35c232b6 100644
--- a/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
+++ b/frontend/src/views/admin/AuthIdentityMigrationReportsView.vue
@@ -318,6 +318,7 @@ const pagination = reactive({
pageSize: 20,
total: 0,
})
+const knownReportTypes = ref<string[]>([])
const columns: Column[] = [
{ key: 'status', label: text('状态', 'Status') },
@@ -330,12 +331,16 @@ const columns: Column[] = [
]
const reportTypeOptions = computed(() =>
- Object.entries(summary.value.by_type)
- .sort(([left], [right]) => left.localeCompare(right))
- .map(([value, count]) => ({
- value,
- label: `${value} (${count})`,
- }))
+ knownReportTypes.value
+ .slice()
+ .sort((left, right) => left.localeCompare(right))
+ .map((value) => {
+ const count = summary.value.by_type[value]
+ return {
+ value,
+ label: count === undefined ? value : `${value} (${count})`,
+ }
+ })
)
const canResolve = computed(() =>
@@ -347,10 +352,22 @@ const canResolve = computed(() =>
)
)
+const mergeKnownReportTypes = (...values: Array<string | undefined>) => {
+ const merged = new Set(knownReportTypes.value)
+ for (const value of values) {
+ const normalized = value?.trim()
+ if (normalized) {
+ merged.add(normalized)
+ }
+ }
+ knownReportTypes.value = Array.from(merged)
+}
+
const loadSummary = async () => {
summaryLoading.value = true
try {
summary.value = await adminAPI.users.getAuthIdentityMigrationReportSummary()
+ mergeKnownReportTypes(...Object.keys(summary.value.by_type))
} catch (error) {
console.error('Failed to load auth identity migration report summary:', error)
appStore.showError(text('加载 migration reports 汇总失败', 'Failed to load migration report summary'))
@@ -370,6 +387,7 @@ const loadReports = async () => {
reports.value = response.items
pagination.total = response.total
+ mergeKnownReportTypes(filters.reportType, ...response.items.map((report) => report.report_type))
if (selectedReport.value) {
const refreshed = response.items.find((report) => report.id === selectedReport.value?.id) ?? null
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 8a042e70..9fb8da41 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2728,8 +2728,8 @@
{{
localText(
- '留空表示自动路由;仅允许当前系统支持的官方或易支付来源。',
- 'Leave blank for automatic routing. Only supported official or EasyPay sources are allowed.'
+ '启用后必须明确选择一个来源;未配置状态不会对外展示该支付方式。',
+ 'Choose an explicit source before enabling the method. Not configured methods are not exposed.'
)
}}
@@ -3450,6 +3450,28 @@ function setPaymentVisibleMethodSource(
form.payment_visible_method_wxpay_source = normalized
}
+function validatePaymentVisibleMethodSelections(): boolean {
+ for (const visibleMethod of paymentVisibleMethodCards.value) {
+ if (!getPaymentVisibleMethodEnabled(visibleMethod.key)) {
+ continue
+ }
+
+ if (getPaymentVisibleMethodSource(visibleMethod.key)) {
+ continue
+ }
+
+ appStore.showError(
+ localText(
+ `${visibleMethod.title} 已启用,请先选择支付来源`,
+ `Select a payment source before enabling ${visibleMethod.title}`
+ )
+ )
+ return false
+ }
+
+ return true
+}
+
// Proxies for web search emulation ProxySelector
const webSearchProxies = ref([])
@@ -3979,6 +4001,10 @@ async function saveSettings() {
}
}
+ if (!validatePaymentVisibleMethodSelections()) {
+ return
+ }
+
// Validate URL fields — novalidate disables browser-native checks, so we validate here
const isValidHttpUrl = (url: string): boolean => {
if (!url) return true
diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
index 5e6b0ae0..406baaf1 100644
--- a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts
@@ -240,4 +240,21 @@ describe('AuthIdentityMigrationReportsView', () => {
reportType: '',
})
})
+
+ it('keeps report type filter options available from list data when summary fails', async () => {
+ getAuthIdentityMigrationReportSummary.mockRejectedValueOnce(new Error('summary failed'))
+ listAuthIdentityMigrationReports.mockResolvedValueOnce(listResponse)
+
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ const options = wrapper
+ .get('[data-test="report-type-filter"]')
+ .findAll('option')
+ .map((node) => node.element.value)
+
+ expect(showError).toHaveBeenCalled()
+ expect(options).toContain('oidc_synthetic_email_requires_manual_recovery')
+ })
})
diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
index f20170e9..b6f8ab17 100644
--- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
@@ -449,4 +449,27 @@ describe('admin SettingsView payment visible method controls', () => {
})
)
})
+
+ it('blocks saving when a visible payment method is enabled without a source', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+ await openPaymentTab(wrapper)
+
+ const paymentSourceSelects = wrapper
+ .findAll('select.select-stub')
+ .filter((node) => ['alipay', 'wxpay'].includes(node.attributes('data-placeholder')))
+
+ const alipaySelect = paymentSourceSelects.find(
+ (node) => node.attributes('data-placeholder') === 'alipay'
+ )
+
+ await alipaySelect?.setValue('')
+ await wrapper.find('form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(updateSettings).not.toHaveBeenCalled()
+ expect(showError).toHaveBeenCalled()
+ expect(String(showError.mock.calls.at(-1)?.[0] ?? '')).toContain('支付来源')
+ })
})
--
GitLab
From e4fe9fae2a5e8992a822f58334660d9f1aaa29a1 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:42:55 +0800
Subject: [PATCH 102/261] Fix profile refresh identity compatibility
---
.../handler/auth_current_user_test.go | 78 ++++++++++++++++
backend/internal/handler/auth_handler.go | 13 ++-
.../handler/auth_oauth_pending_flow.go | 18 +++-
.../handler/auth_oauth_pending_flow_test.go | 72 +++++++++++++++
backend/internal/handler/user_handler.go | 88 ++-----------------
backend/internal/handler/user_handler_test.go | 13 +--
backend/internal/service/user_service.go | 5 ++
7 files changed, 195 insertions(+), 92 deletions(-)
create mode 100644 backend/internal/handler/auth_current_user_test.go
diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go
new file mode 100644
index 00000000..dab95e29
--- /dev/null
+++ b/backend/internal/handler/auth_current_user_test.go
@@ -0,0 +1,78 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 31,
+ Email: "me@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-31",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+
+ handler := &AuthHandler{
+ userService: service.NewUserService(repo, nil, nil, nil),
+ }
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
+
+ handler.GetCurrentUser(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+
+ _, hasAvatarSource := resp.Data["avatar_source"]
+ require.False(t, hasAvatarSource)
+ _, hasProfileSources := resp.Data["profile_sources"]
+ require.False(t, hasProfileSources)
+}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index b984a436..76ca153d 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -348,8 +348,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
+ identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
type UserResponse struct {
- *dto.User
+ userProfileResponse
RunMode string `json:"run_mode"`
}
@@ -358,7 +364,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
- response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
+ response.Success(c, UserResponse{
+ userProfileResponse: userProfileResponseFromService(user, identities),
+ RunMode: runMode,
+ })
}
// ValidatePromoCodeRequest 验证优惠码请求
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 94186858..461810f1 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -848,6 +848,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
}
}
+func shouldSkipAvatarAdoption(err error) bool {
+ return errors.Is(err, service.ErrAvatarInvalid) ||
+ errors.Is(err, service.ErrAvatarTooLarge) ||
+ errors.Is(err, service.ErrAvatarNotImage)
+}
+
func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
@@ -885,6 +891,14 @@ func applyPendingOAuthBinding(
if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
+ shouldAdoptAvatar := false
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
+ shouldAdoptAvatar = true
+ } else if !shouldSkipAvatarAdoption(err) {
+ return err
+ }
+ }
tx, err := client.Tx(ctx)
if err != nil {
@@ -913,7 +927,7 @@ func applyPendingOAuthBinding(
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
- if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if shouldAdoptAvatar {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -939,7 +953,7 @@ func applyPendingOAuthBinding(
}
}
- if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil {
+ if shouldAdoptAvatar && userService != nil {
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 2521186e..d29e4b88 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -173,6 +173,78 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio
require.NotNil(t, consumed.ConsumedAt)
}
+func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("invalid-avatar@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-invalid-avatar-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invalid-avatar-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-invalid-avatar-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "/avatars/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invalid-avatar-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ _, hasAdoptedAvatar := identity.Metadata["avatar_url"]
+ require.False(t, hasAdoptedAvatar)
+
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.Nil(t, avatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index b1ade5c0..843b0bd9 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -2,7 +2,6 @@ package handler
import (
"context"
- "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -353,22 +352,16 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
return userProfileResponse{}
}
bindings := userProfileBindingMap(identities)
- profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
- User: *base,
- AvatarURL: user.AvatarURL,
- AvatarSource: avatarSource,
- UsernameSource: usernameSource,
- DisplayNameSource: usernameSource,
- NicknameSource: usernameSource,
- ProfileSources: profileSources,
- Identities: identities,
- AuthBindings: bindings,
- IdentityBindings: bindings,
- EmailBound: identities.Email.Bound,
- LinuxDoBound: identities.LinuxDo.Bound,
- OIDCBound: identities.OIDC.Bound,
- WeChatBound: identities.WeChat.Bound,
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
}
}
@@ -380,66 +373,3 @@ func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string
"wechat": identities.WeChat,
}
}
-
-func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
- map[string]*userProfileSourceContext,
- *userProfileSourceContext,
- *userProfileSourceContext,
-) {
- if user == nil {
- return nil, nil, nil
- }
-
- thirdParty := thirdPartyIdentityProviders(identities)
- var avatarSource *userProfileSourceContext
- if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
- avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
- }
-
- usernameValue := strings.TrimSpace(user.Username)
- var usernameSource *userProfileSourceContext
- for _, summary := range thirdParty {
- if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
- usernameSource = buildUserProfileSourceContext(summary.Provider)
- break
- }
- }
- if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
- usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
- }
-
- profileSources := map[string]*userProfileSourceContext{}
- if avatarSource != nil {
- profileSources["avatar"] = avatarSource
- }
- if usernameSource != nil {
- profileSources["username"] = usernameSource
- profileSources["display_name"] = usernameSource
- profileSources["nickname"] = usernameSource
- }
- if len(profileSources) == 0 {
- return nil, avatarSource, usernameSource
- }
- return profileSources, avatarSource, usernameSource
-}
-
-func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
- out := make([]service.UserIdentitySummary, 0, 3)
- for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
- if summary.Bound {
- out = append(out, summary)
- }
- }
- return out
-}
-
-func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
- provider = strings.TrimSpace(provider)
- if provider == "" {
- return nil
- }
- return &userProfileSourceContext{
- Provider: provider,
- Source: provider,
- }
-}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index 1216f9c4..7c6460e8 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -298,15 +298,10 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.True(t, ok)
require.Equal(t, true, emailBinding["bound"])
- avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "linuxdo", avatarSource["provider"])
-
- profileSources, ok := resp.Data["profile_sources"].(map[string]any)
- require.True(t, ok)
- usernameSource, ok := profileSources["username"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "linuxdo", usernameSource["provider"])
+ _, hasAvatarSource := resp.Data["avatar_source"]
+ require.False(t, hasAvatarSource)
+ _, hasProfileSources := resp.Data["profile_sources"]
+ require.False(t, hasProfileSources)
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index c106a3f5..cd1bc2bb 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -403,6 +403,11 @@ func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
}, nil
}
+func ValidateUserAvatar(raw string) error {
+ _, err := normalizeUserAvatarInput(raw)
+ return err
+}
+
func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
body := strings.TrimPrefix(raw, "data:")
meta, encoded, ok := strings.Cut(body, ",")
--
GitLab
From 12f4af742f5c7c40503430f468c5e0e42894c72f Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:45:56 +0800
Subject: [PATCH 103/261] Fix auth pending adoption and Turnstile flow
---
.../auth/PendingOAuthCreateAccountForm.vue | 62 +++++++++++++-
.../PendingOAuthCreateAccountForm.spec.ts | 50 +++++++++++-
.../src/views/auth/LinuxDoCallbackView.vue | 28 ++++++-
frontend/src/views/auth/OidcCallbackView.vue | 28 ++++++-
.../__tests__/LinuxDoCallbackView.spec.ts | 81 ++++++++++++++++++-
.../auth/__tests__/OidcCallbackView.spec.ts | 78 +++++++++++++++++-
6 files changed, 313 insertions(+), 14 deletions(-)
diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
index 39588a86..36e78d36 100644
--- a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
+++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue
@@ -16,6 +16,15 @@
placeholder="Password"
:disabled="isSubmitting"
/>
+
+
+
{{
@@ -80,9 +89,10 @@
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index 1c9531e3..7db2ecd7 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -1,6 +1,8 @@
import { mount } from '@vue/test-utils'
+import { createPinia, setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
+import { useAppStore } from '@/stores'
import type { User } from '@/types'
const routeState = vi.hoisted(() => ({
@@ -11,6 +13,8 @@ const locationState = vi.hoisted(() => ({
current: { href: 'http://localhost/profile' } as { href: string },
}))
+let pinia: ReturnType<typeof createPinia>
+
vi.mock('vue-router', () => ({
useRoute: () => routeState,
}))
@@ -57,6 +61,8 @@ function createUser(overrides: Partial = {}): User {
describe('ProfileIdentityBindingsSection', () => {
beforeEach(() => {
+ pinia = createPinia()
+ setActivePinia(pinia)
routeState.fullPath = '/profile'
locationState.current = { href: 'http://localhost/profile' }
Object.defineProperty(window, 'location', {
@@ -67,6 +73,9 @@ describe('ProfileIdentityBindingsSection', () => {
configurable: true,
value: 'Mozilla/5.0',
})
+ const appStore = useAppStore()
+ appStore.cachedPublicSettings = null
+ appStore.publicSettingsLoaded = false
})
afterEach(() => {
@@ -75,6 +84,9 @@ describe('ProfileIdentityBindingsSection', () => {
it('renders provider binding states and provider-specific bind actions', () => {
const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
props: {
user: createUser({
auth_bindings: {
@@ -102,11 +114,16 @@ describe('ProfileIdentityBindingsSection', () => {
it('starts the WeChat bind flow for the current profile page', async () => {
const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
props: {
user: createUser(),
linuxdoEnabled: false,
oidcEnabled: false,
wechatEnabled: true,
+ wechatOpenEnabled: true,
+ wechatMpEnabled: false,
},
})
@@ -117,4 +134,22 @@ describe('ProfileIdentityBindingsSection', () => {
expect(locationState.current.href).toContain('intent=bind_current_user')
expect(locationState.current.href).toContain('redirect=%2Fprofile')
})
+
+ it('hides the WeChat bind action outside the WeChat browser when only mp mode is configured', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: createUser(),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: true,
+ wechatOpenEnabled: false,
+ wechatMpEnabled: true,
+ },
+ })
+
+ expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(false)
+ })
})
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index a8e03a51..0ed3f37f 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -338,6 +338,8 @@ export const useAppStore = defineStore('app', () => {
custom_endpoints: [],
linuxdo_oauth_enabled: false,
wechat_oauth_enabled: false,
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: false,
oidc_oauth_enabled: false,
oidc_oauth_provider_name: 'OIDC',
backend_mode_enabled: false,
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 5a2e3184..07341919 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -165,6 +165,8 @@ export interface PublicSettings {
custom_endpoints: CustomEndpoint[]
linuxdo_oauth_enabled: boolean
wechat_oauth_enabled: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
oidc_oauth_enabled: boolean
oidc_oauth_provider_name: string
backend_mode_enabled: boolean
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 10b83b1c..a5da35e5 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -297,6 +297,7 @@ import {
login2FA,
prepareOAuthBindAccessTokenCookie,
persistOAuthTokenContext,
+ resolveWeChatOAuthStart,
type OAuthAdoptionDecision,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -378,7 +379,47 @@ function normalizeWeChatOAuthMode(value: unknown): 'open' | 'mp' | null {
return value === 'open' || value === 'mp' ? value : null
}
-function resolveRequestedWeChatOAuthMode(): 'open' | 'mp' {
+async function ensurePublicSettingsLoaded(): Promise<void> {
+ if (appStore.cachedPublicSettings || appStore.publicSettingsLoaded) {
+ return
+ }
+
+ try {
+ await appStore.fetchPublicSettings()
+ } catch {
+ // Fall back to legacy mode selection when public settings are unavailable.
+ }
+}
+
+function resolveConfiguredWeChatOAuthMode(): 'open' | 'mp' | null {
+ if (!appStore.cachedPublicSettings && !appStore.publicSettingsLoaded) {
+ return null
+ }
+
+ return resolveWeChatOAuthStart(appStore.cachedPublicSettings).mode
+}
+
+function resolveWeChatOAuthUnavailableMessage(): string {
+ const resolved = resolveWeChatOAuthStart(appStore.cachedPublicSettings)
+
+ switch (resolved.unavailableReason) {
+ case 'external_browser_required':
+ return 'This WeChat sign-in flow is only available in your system browser.'
+ case 'wechat_browser_required':
+ return 'This WeChat sign-in flow is only available inside the WeChat browser.'
+ case 'not_configured':
+ return 'WeChat sign-in is not configured yet.'
+ default:
+ return t('auth.loginFailed')
+ }
+}
+
+function resolveRequestedWeChatOAuthMode(): 'open' | 'mp' | null {
+ const configuredMode = resolveConfiguredWeChatOAuthMode()
+ if (configuredMode) {
+ return configuredMode
+ }
+
const queryMode = normalizeWeChatOAuthMode(route.query.mode)
return queryMode || resolveWeChatOAuthMode()
}
@@ -389,11 +430,15 @@ function resolveRedirectTarget(): string {
)
}
-function resolveWeChatStartURL(intent: 'bind_current_user' | 'adopt_existing_user_by_email'): string {
+function resolveWeChatStartURL(intent: 'bind_current_user' | 'adopt_existing_user_by_email'): string | null {
const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1'
const normalized = apiBase.replace(/\/$/, '')
+ const mode = resolveRequestedWeChatOAuthMode()
+ if (!mode) {
+ return null
+ }
const params = new URLSearchParams({
- mode: resolveRequestedWeChatOAuthMode(),
+ mode,
redirect: resolveRedirectTarget(),
intent,
})
@@ -406,11 +451,15 @@ function resolveWeChatStartURL(intent: 'bind_current_user' | 'adopt_existing_use
return `${normalized}/auth/oauth/wechat/start?${params.toString()}`
}
-function buildExistingAccountResumePath(): string {
+function buildExistingAccountResumePath(): string | null {
+ const mode = resolveRequestedWeChatOAuthMode()
+ if (!mode) {
+ return null
+ }
const params = new URLSearchParams({
wechat_bind_existing: '1',
redirect: resolveRedirectTarget(),
- mode: resolveRequestedWeChatOAuthMode(),
+ mode,
})
const email = existingAccountEmail.value.trim()
@@ -444,14 +493,31 @@ function serializeAdoptionDecision(decision: OAuthAdoptionDecision): Record {
+ await ensurePublicSettingsLoaded()
+
if (typeof route.query.email === 'string') {
existingAccountEmail.value = route.query.email
}
if (route.query.wechat_bind_existing === '1') {
if (getAuthToken()) {
+ const startURL = resolveWeChatStartURL('bind_current_user')
+ if (!startURL) {
+ errorMessage.value = resolveWeChatOAuthUnavailableMessage()
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
prepareOAuthBindAccessTokenCookie()
- window.location.href = resolveWeChatStartURL('bind_current_user')
+ window.location.href = startURL
+ return
+ }
+
+ const resumePath = buildExistingAccountResumePath()
+ if (!resumePath) {
+ errorMessage.value = resolveWeChatOAuthUnavailableMessage()
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
return
}
const params = new URLSearchParams({
- redirect: buildExistingAccountResumePath(),
+ redirect: resumePath,
})
const email = existingAccountEmail.value.trim()
if (email) {
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index aa673238..7f26f3c8 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -14,8 +14,10 @@ const {
setTokenMock,
showSuccessMock,
showErrorMock,
+ fetchPublicSettingsMock,
routeState,
locationState,
+ appStoreState,
} = vi.hoisted(() => ({
exchangePendingOAuthCompletionMock: vi.fn(),
completeWeChatOAuthRegistrationMock: vi.fn(),
@@ -28,6 +30,7 @@ const {
setTokenMock: vi.fn(),
showSuccessMock: vi.fn(),
showErrorMock: vi.fn(),
+ fetchPublicSettingsMock: vi.fn(),
routeState: {
query: {} as Record,
},
@@ -39,6 +42,10 @@ const {
pathname: '/auth/wechat/callback'
} as { href: string; hash: string; search: string; pathname: string },
},
+ appStoreState: {
+ cachedPublicSettings: null as null | Record,
+ publicSettingsLoaded: false,
+ },
}))
vi.mock('vue-router', () => ({
@@ -102,8 +109,10 @@ vi.mock('@/stores', () => ({
setToken: setTokenMock,
}),
useAppStore: () => ({
+ ...appStoreState,
showSuccess: showSuccessMock,
showError: showErrorMock,
+ fetchPublicSettings: fetchPublicSettingsMock,
}),
}))
@@ -139,7 +148,10 @@ describe('WechatCallbackView', () => {
showErrorMock.mockReset()
prepareOAuthBindAccessTokenCookieMock.mockReset()
getAuthTokenMock.mockReset()
+ fetchPublicSettingsMock.mockReset()
routeState.query = {}
+ appStoreState.cachedPublicSettings = null
+ appStoreState.publicSettingsLoaded = false
localStorage.clear()
locationState.current = {
href: 'http://localhost/auth/wechat/callback',
@@ -157,6 +169,38 @@ describe('WechatCallbackView', () => {
})
})
+ it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => {
+ routeState.query = {
+ wechat_bind_existing: '1',
+ mode: 'mp',
+ redirect: '/profile',
+ }
+ appStoreState.cachedPublicSettings = {
+ wechat_oauth_enabled: true,
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ }
+ appStoreState.publicSettingsLoaded = true
+ getAuthTokenMock.mockReturnValue('current-auth-token')
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1)
+ expect(locationState.current.href).toContain('mode=open')
+ expect(locationState.current.href).not.toContain('mode=mp')
+ })
+
it('does not send adoption decisions during the initial exchange', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
access_token: 'access-token',
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index f7418be9..14d7efea 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -67,7 +67,6 @@
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index d8199e3b..c4e38523 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -1,4 +1,4 @@
-import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { flushPromises, mount } from '@vue/test-utils'
const routeState = vi.hoisted(() => ({
@@ -73,7 +73,11 @@ describe('PaymentResultView', () => {
window.localStorage.clear()
})
- it('restores order id from a matching resume token and does not trust query success flags', async () => {
+ afterEach(() => {
+ vi.useRealTimers()
+ })
+
+ it('renders a pending state instead of a failure state when the restored order is still pending', async () => {
routeState.query = {
resume_token: 'resume-42',
order_id: '999',
@@ -107,8 +111,43 @@ describe('PaymentResultView', () => {
expect(pollOrderStatus).toHaveBeenCalledWith(42)
expect(verifyOrderPublic).not.toHaveBeenCalled()
- expect(wrapper.text()).toContain('payment.result.failed')
+ expect(wrapper.text()).toContain('payment.result.processing')
expect(wrapper.text()).not.toContain('payment.result.success')
+ expect(wrapper.text()).not.toContain('payment.result.failed')
+ })
+
+ it('refreshes a pending resume-token result until the order becomes paid', async () => {
+ vi.useFakeTimers()
+ routeState.query = {
+ resume_token: 'resume-77',
+ }
+ resolveOrderPublicByResumeToken
+ .mockResolvedValueOnce({
+ data: orderFactory('PENDING'),
+ })
+ .mockResolvedValueOnce({
+ data: orderFactory('PAID'),
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledTimes(1)
+ expect(wrapper.text()).toContain('payment.result.processing')
+
+ await vi.advanceTimersByTimeAsync(2000)
+ await flushPromises()
+
+ expect(resolveOrderPublicByResumeToken).toHaveBeenCalledTimes(2)
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.text()).not.toContain('payment.result.failed')
})
it('does not fall back to public out_trade_no verification when resume_token recovery fails', async () => {
--
GitLab
From 7a9488ff37b07d70ae4ad62ac1dab17e37bc7470 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 00:59:20 +0800
Subject: [PATCH 106/261] Add legacy identity safety remediation migration
---
...ntity_legacy_migration_integration_test.go | 246 +++++++++++++
...dentity_legacy_external_safety_reports.sql | 336 ++++++++++++++++++
2 files changed, 582 insertions(+)
create mode 100644 backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
index 6a6312d4..7f2f363f 100644
--- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -200,6 +200,252 @@ FROM auth_identity_migration_reports
var afterCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
+FROM auth_identity_migration_reports
+ `).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+ TRUNCATE TABLE
+ auth_identity_channels,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_external_identities,
+ users
+ RESTART IDENTITY;
+`)
+ require.NoError(t, err)
+
+ userIDs := make([]int64, 0, 8)
+ for _, email := range []string{
+ "linuxdo-conflict-legacy@example.com",
+ "linuxdo-conflict-owner@example.com",
+ "wechat-conflict-legacy@example.com",
+ "wechat-conflict-owner@example.com",
+ "wechat-channel-legacy@example.com",
+ "wechat-channel-owner@example.com",
+ "linuxdo-invalid-json@example.com",
+ "wechat-openid-invalid-json@example.com",
+ } {
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ($1, 'hash', 'user', 'active', 0, 1)
+RETURNING id`, email).Scan(&userID))
+ userIDs = append(userIDs, userID)
+ }
+
+ linuxdoConflictLegacyUserID := userIDs[0]
+ linuxdoConflictOwnerUserID := userIDs[1]
+ wechatConflictLegacyUserID := userIDs[2]
+ wechatConflictOwnerUserID := userIDs[3]
+ wechatChannelLegacyUserID := userIDs[4]
+ wechatChannelOwnerUserID := userIDs[5]
+ linuxdoInvalidJSONUserID := userIDs[6]
+ wechatInvalidOpenIDUserID := userIDs[7]
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb)
+RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb)
+RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64)))
+
+ var wechatChannelOwnerIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64)))
+
+ var linuxdoConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}')
+RETURNING id
+`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID))
+
+ var wechatConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}')
+RETURNING id
+`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID))
+
+ var wechatChannelConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}')
+RETURNING id
+`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID))
+
+ var linuxdoInvalidJSONLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid')
+RETURNING id
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID))
+
+ var wechatInvalidOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid')
+RETURNING id
+`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxdoConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount))
+ require.Equal(t, 1, linuxdoConflictReportCount)
+
+ var wechatConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount))
+ require.Equal(t, 1, wechatConflictReportCount)
+
+ var channelConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_channel_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount))
+ require.Equal(t, 1, channelConflictReportCount)
+
+ var invalidJSONReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key IN ($1, $2)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount))
+ require.Equal(t, 2, invalidJSONReportCount)
+
+ var linuxdoInvalidIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-invalid-json'
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount))
+ require.Equal(t, 1, linuxdoInvalidIdentityCount)
+
+ var wechatOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount))
+ require.Equal(t, 1, wechatOpenIDOnlyReportCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
new file mode 100644
index 00000000..994f3f37
--- /dev/null
+++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
@@ -0,0 +1,336 @@
+CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
+CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT)
+RETURNS BOOLEAN
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN TRUE;
+ END IF;
+
+ parsed := input_text::jsonb;
+ RETURN TRUE;
+EXCEPTION
+ WHEN OTHERS THEN
+ RETURN FALSE;
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_invalid_metadata_json',
+ 'legacy_external_identity:' || uei.id::text,
+ jsonb_build_object(
+ 'legacy_identity_id', uei.id,
+ 'user_id', uei.user_id,
+ 'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))),
+ 'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')),
+ 'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')),
+ 'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object',
+ 'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000),
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM user_external_identities AS uei
+JOIN users AS u ON u.id = uei.user_id
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(uei.metadata, '')) <> ''
+ AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata)
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_identity_id', ai.id,
+ 'existing_user_id', ai.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'reason', 'legacy canonical identity subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ legacy.provider_type,
+ legacy.provider_key,
+ legacy.provider_subject,
+ legacy.verified_at,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', NULLIF(legacy.provider_union_id, ''),
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+LEFT JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.id IS NULL
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_channel_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_channel_id', channel.id,
+ 'existing_identity_id', existing_ai.id,
+ 'existing_user_id', existing_ai.user_id,
+ 'provider_type', 'wechat',
+ 'provider_key', 'wechat-main',
+ 'provider_subject', legacy.provider_union_id,
+ 'channel', legacy.channel,
+ 'channel_app_id', legacy.channel_app_id,
+ 'channel_subject', legacy.provider_user_id,
+ 'reason', 'legacy channel subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+JOIN auth_identities AS existing_ai
+ ON existing_ai.id = channel.identity_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND existing_ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ legacy_ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+LEFT JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND channel.id IS NULL
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);
+DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);
--
GitLab
From ea27ac6fd7d21ec5b6ce6b861b3d7889293624d3 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 01:00:59 +0800
Subject: [PATCH 107/261] fix: unify email identity sync and retry first-bind
defaults
---
backend/internal/repository/user_repo.go | 8 --
backend/internal/service/admin_service.go | 7 --
.../admin_service_email_identity_sync_test.go | 28 +++----
backend/internal/service/auth_service.go | 81 ++++++++++++++++---
.../auth_service_identity_sync_test.go | 68 ++++++++++++++++
backend/internal/service/user_service.go | 31 -------
.../user_service_email_identity_sync_test.go | 9 +--
7 files changed, 154 insertions(+), 78 deletions(-)
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 25d3f1d6..195776a3 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -209,14 +209,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
-func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
- return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
-}
-
-func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
- return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
-}
-
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
client = clientFromContext(ctx, client)
if client == nil || userID <= 0 {
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 10b85f76..ce1c1a77 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -650,9 +650,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
- if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
- return nil, fmt.Errorf("sync email auth identity: %w", err)
- }
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
@@ -688,7 +685,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
- oldEmail := user.Email
if input.Email != "" {
user.Email = input.Email
@@ -721,9 +717,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
- if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
- return nil, fmt.Errorf("sync email auth identity: %w", err)
- }
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
index d555d609..d6a7af9a 100644
--- a/backend/internal/service/admin_service_email_identity_sync_test.go
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -31,6 +31,8 @@ type emailSyncRepoStub struct {
updated []*User
ensureCalls []ensureEmailCall
replaceCalls []replaceEmailCall
+ ensureErr error
+ replaceErr error
}
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
@@ -125,7 +127,7 @@ func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return n
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
- return nil
+ return s.ensureErr
}
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
@@ -134,11 +136,14 @@ func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID i
oldEmail: oldEmail,
newEmail: newEmail,
})
- return nil
+ return s.replaceErr
}
-func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
- repo := &emailSyncRepoStub{nextID: 55}
+func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ nextID: 55,
+ ensureErr: fmt.Errorf("unexpected email resync"),
+ }
svc := &adminServiceImpl{userRepo: repo}
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
@@ -147,14 +152,12 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
})
require.NoError(t, err)
require.NotNil(t, user)
- require.Equal(t, []ensureEmailCall{{
- userID: 55,
- email: "admin-created@example.com",
- }}, repo.ensureCalls)
+ require.Equal(t, int64(55), user.ID)
+ require.Empty(t, repo.ensureCalls)
require.Empty(t, repo.replaceCalls)
}
-func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
+func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 91,
@@ -163,6 +166,7 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
Status: StatusActive,
Concurrency: 3,
},
+ replaceErr: fmt.Errorf("unexpected email resync"),
}
svc := &adminServiceImpl{userRepo: repo}
@@ -172,10 +176,6 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "after@example.com", updated.Email)
- require.Equal(t, []replaceEmailCall{{
- userID: 91,
- oldEmail: "before@example.com",
- newEmail: "after@example.com",
- }}, repo.replaceCalls)
+ require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index d0d5e4e3..00fefd82 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -768,9 +768,6 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
- if signupSource == "email" {
- s.ensureEmailAuthIdentity(ctx, user)
- }
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
@@ -807,21 +804,81 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
if s == nil || user == nil || user.ID <= 0 {
return
}
- if s.ensureEmailAuthIdentity(ctx, user) {
+ identity, created := s.ensureEmailAuthIdentity(ctx, user)
+ if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
}
}
}
-func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
- if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+func (s *AuthService) shouldApplyEmailFirstBindDefaults(
+ ctx context.Context,
+ userID int64,
+ identity *dbent.AuthIdentity,
+ created bool,
+) bool {
+ if created {
+ return true
+ }
+ if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
+ return false
+ }
+ if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" {
+ return false
+ }
+
+ hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
return false
}
+ return !hasGrant
+}
+
+func emailAuthIdentitySource(metadata map[string]any) string {
+ if len(metadata) == 0 {
+ return ""
+ }
+ raw, ok := metadata["source"]
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(fmt.Sprint(raw))
+}
+
+func (s *AuthService) hasProviderGrantRecord(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+ grantReason string,
+) (bool, error) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return false, nil
+ }
+
+ rows, err := s.entClient.QueryContext(
+ ctx,
+ `SELECT 1 FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ? LIMIT 1`,
+ userID,
+ strings.TrimSpace(providerType),
+ strings.TrimSpace(grantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ defer rows.Close()
+ return rows.Next(), rows.Err()
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return nil, false
+ }
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
- return false
+ return nil, false
}
client := s.entClient
@@ -840,7 +897,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
existed, err := buildQuery().Exist(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
- return false
+ return nil, false
}
if !existed {
@@ -861,21 +918,21 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
DoNothing().
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
- return false
+ return nil, false
}
}
identity, err := buildQuery().Only(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
- return false
+ return nil, false
}
if identity.UserID != user.ID {
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
- return false
+ return nil, false
}
- return !existed
+ return identity, !existed
}
func inferLegacySignupSource(email string) string {
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
index e2a94b13..95c9c933 100644
--- a/backend/internal/service/auth_service_identity_sync_test.go
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -5,6 +5,7 @@ package service_test
import (
"context"
"database/sql"
+ "errors"
"testing"
"time"
@@ -34,6 +35,24 @@ func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
}
+type flakyAuthIdentityDefaultSubAssignerStub struct {
+ failuresRemaining int
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ if s.failuresRemaining > 0 {
+ s.failuresRemaining--
+ return nil, false, errors.New("temporary assign failure")
+ }
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
type authIdentitySettingRepoStub struct {
values map[string]string
}
@@ -333,6 +352,55 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
+func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) {
+ assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("retry-first-bind@example.com").
+ SetUsername("retry-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 10.0, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Len(t, assigner.calls, 2)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index cd1bc2bb..7c2ca2d0 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -161,33 +161,6 @@ type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
-type emailAuthIdentitySynchronizer interface {
- EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
- ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
-}
-
-func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
- syncer, ok := repo.(emailAuthIdentitySynchronizer)
- if !ok {
- return nil
- }
- return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
-}
-
-func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
- oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
- newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
- if oldNormalized == newNormalized {
- return nil
- }
-
- syncer, ok := repo.(emailAuthIdentitySynchronizer)
- if !ok {
- return nil
- }
- return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
-}
-
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -281,7 +254,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
- oldEmail := user.Email
// 更新字段
if req.Email != nil {
@@ -326,9 +298,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
- if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
- return nil, fmt.Errorf("sync email auth identity: %w", err)
- }
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
index 8109b368..702b3b1a 100644
--- a/backend/internal/service/user_service_email_identity_sync_test.go
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
-func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
+func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 19,
@@ -17,6 +17,7 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
Username: "tester",
Concurrency: 2,
},
+ replaceErr: context.DeadlineExceeded,
}
svc := NewUserService(repo, nil, nil, nil)
@@ -28,10 +29,6 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
require.NotNil(t, updated)
require.Equal(t, newEmail, updated.Email)
require.Equal(t, 1, repo.updateCalls)
- require.Equal(t, []replaceEmailCall{{
- userID: 19,
- oldEmail: "profile-before@example.com",
- newEmail: "profile-after@example.com",
- }}, repo.replaceCalls)
+ require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}
--
GitLab
From 365ef1fdf79b51f3aec991d335311a201f5aa116 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 01:05:14 +0800
Subject: [PATCH 108/261] refactor(channels): consolidate pricing index,
tighten types, polish DTOs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Follow-up to the available-channels review pass. No behavior change for
end users; tightens internals based on three independent code reviews.
Backend
- service/channel.go: collapse buildPricingLookup + pricedNamesFor
into a single platformPricingIndex (byLower + originalCase + ordered
names), built once per SupportedModels call. Fixes a casing-
consistency bug where the same logical model appeared with mapping
case in the exact branch but pricing case in the wildcard branch —
pricing's original case now wins everywhere.
- service/channel.go: document that a mapping key of just "*" expands
to every priced model on the platform (intentional "passthrough all").
- service/channel_available.go: normalize empty BillingModelSource to
channel_mapped at construction time, removing the same fallback
duplicated in the admin DTO mapper and the admin Vue template.
- handler/admin/available_channel_handler.go: unexport
availableChannelToAdminResponse (same-package usage only); mapper
is now a pure passthrough.
- handler/available_channel_handler.go: drop the middleware2 alias
(no name collision in this file).
Frontend
- utils/pricing.ts: extract formatScaled, used by SupportedModelChip
and PricingRow.
- api/admin/channels.ts: re-export BillingMode from constants/channel;
tighten Channel.status / billing_model_source to ChannelStatus /
BillingModelSource (and same for AvailableChannel).
- components/channels/AvailableChannelsTable.vue: drop dead
withDefaults wrapper (loading is required, both call sites pass it).
- views/admin/AvailableChannelsView.vue: drop the redundant
|| BILLING_MODEL_SOURCE_CHANNEL_MAPPED fallback (now applied in
service layer); remove unused import.
- i18n zh + en: delete unused tierLabel and tokenRange keys from
both availableChannels.pricing and admin.availableChannels.pricing.
Tests
- New: SupportedModels_ExactKeyUsesPricedCaseWhenAvailable locks the
pricing-case-wins rule.
- New: SupportedModels_AsteriskOnlyMappingExpandsAllPriced documents
the "*" expansion rule.
- Admin handler: existing tests adjusted to pass an explicit
BillingModelSource (default-fill is now exercised by service tests).
---
.../admin/available_channel_handler.go | 14 +--
.../admin/available_channel_handler_test.go | 8 +-
.../handler/available_channel_handler.go | 4 +-
backend/internal/service/channel.go | 96 +++++++++----------
backend/internal/service/channel_available.go | 7 +-
backend/internal/service/channel_test.go | 31 ++++++
frontend/src/api/admin/channels.ts | 11 ++-
.../channels/AvailableChannelsTable.vue | 21 ++--
.../src/components/channels/PricingRow.vue | 6 +-
.../channels/SupportedModelChip.vue | 6 +-
frontend/src/i18n/locales/en.ts | 4 -
frontend/src/i18n/locales/zh.ts | 4 -
frontend/src/utils/pricing.ts | 13 +++
.../src/views/admin/AvailableChannelsView.vue | 11 +--
14 files changed, 128 insertions(+), 108 deletions(-)
create mode 100644 frontend/src/utils/pricing.ts
diff --git a/backend/internal/handler/admin/available_channel_handler.go b/backend/internal/handler/admin/available_channel_handler.go
index 53776105..45b8f357 100644
--- a/backend/internal/handler/admin/available_channel_handler.go
+++ b/backend/internal/handler/admin/available_channel_handler.go
@@ -45,9 +45,9 @@ type availableChannelResponse struct {
SupportedModels []supportedModelResponse `json:"supported_models"`
}
-// AvailableChannelToAdminResponse 将 service 层的 AvailableChannel 转为管理员 DTO。
-// 导出供同 package 的复用;也用于构造测试 fixture。
-func AvailableChannelToAdminResponse(ch service.AvailableChannel) availableChannelResponse {
+// availableChannelToAdminResponse 将 service 层的 AvailableChannel 转为管理员 DTO。
+// 同 package 内复用;也用于构造测试 fixture。
+func availableChannelToAdminResponse(ch service.AvailableChannel) availableChannelResponse {
groups := make([]availableGroupResponse, 0, len(ch.Groups))
for _, g := range ch.Groups {
groups = append(groups, availableGroupResponse{ID: g.ID, Name: g.Name, Platform: g.Platform})
@@ -66,16 +66,12 @@ func AvailableChannelToAdminResponse(ch service.AvailableChannel) availableChann
Pricing: pricing,
})
}
- billingSource := ch.BillingModelSource
- if billingSource == "" {
- billingSource = service.BillingModelSourceChannelMapped
- }
return availableChannelResponse{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
- BillingModelSource: billingSource,
+ BillingModelSource: ch.BillingModelSource,
RestrictModels: ch.RestrictModels,
Groups: groups,
SupportedModels: models,
@@ -93,7 +89,7 @@ func (h *AvailableChannelHandler) List(c *gin.Context) {
out := make([]availableChannelResponse, 0, len(channels))
for _, ch := range channels {
- out = append(out, AvailableChannelToAdminResponse(ch))
+ out = append(out, availableChannelToAdminResponse(ch))
}
response.Success(c, gin.H{"items": out})
}
diff --git a/backend/internal/handler/admin/available_channel_handler_test.go b/backend/internal/handler/admin/available_channel_handler_test.go
index 687e8dad..7d249383 100644
--- a/backend/internal/handler/admin/available_channel_handler_test.go
+++ b/backend/internal/handler/admin/available_channel_handler_test.go
@@ -12,13 +12,13 @@ import (
func TestAvailableChannelToAdminResponse_IncludesFullDTO(t *testing.T) {
// 管理员视图应包含 id / status / billing_model_source / restrict_models 等
- // 管理字段;BillingModelSource 为空时应默认回填 channel_mapped。
+ // 管理字段;mapper 是纯透传,BillingModelSource 的默认回填由 service 层负责。
input := service.AvailableChannel{
ID: 42,
Name: "ch",
Description: "d",
Status: service.StatusActive,
- BillingModelSource: "", // 验证默认值填充
+ BillingModelSource: service.BillingModelSourceChannelMapped,
RestrictModels: true,
Groups: []service.AvailableGroupRef{
{ID: 1, Name: "g1", Platform: "anthropic"},
@@ -28,7 +28,7 @@ func TestAvailableChannelToAdminResponse_IncludesFullDTO(t *testing.T) {
},
}
- resp := AvailableChannelToAdminResponse(input)
+ resp := availableChannelToAdminResponse(input)
require.Equal(t, int64(42), resp.ID)
require.Equal(t, "ch", resp.Name)
require.Equal(t, service.StatusActive, resp.Status)
@@ -52,6 +52,6 @@ func TestAvailableChannelToAdminResponse_PreservesExplicitBillingSource(t *testi
input := service.AvailableChannel{
BillingModelSource: service.BillingModelSourceUpstream,
}
- resp := AvailableChannelToAdminResponse(input)
+ resp := availableChannelToAdminResponse(input)
require.Equal(t, service.BillingModelSourceUpstream, resp.BillingModelSource)
}
diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go
index 25452fc8..d19fa9b6 100644
--- a/backend/internal/handler/available_channel_handler.go
+++ b/backend/internal/handler/available_channel_handler.go
@@ -2,7 +2,7 @@ package handler
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -83,7 +83,7 @@ type userAvailableChannel struct {
// List 列出当前用户可见的「可用渠道」。
// GET /api/v1/channels/available
func (h *AvailableChannelHandler) List(c *gin.Context) {
- subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index de31e829..d142146d 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -391,59 +391,50 @@ func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelMode
return nil
}
-// pricingLookup 是渠道定价在单个计算过程中的索引:platform → (lowerName → *pricing)。
-// 用于将 SupportedModels 的定价解析从 O(N*M) 降到 O(N+M)。
-type pricingLookup map[string]map[string]*ChannelModelPricing
+// platformPricingIndex 是单个平台下定价信息的复合索引。
+// 一次扫描即可同时支持精确查找(exact 分支)与有序遍历(wildcard 分支),
+// 避免 SupportedModels 对每个平台重复扫描定价列表。
+//
+// byLower 与 names/originalCase 共享同一套去重规则:以 lower-case 模型名为 key,
+// 首个命中保留其原始大小写。names 维持按定价行扫描顺序的稳定迭代。
+type platformPricingIndex struct {
+ byLower map[string]*ChannelModelPricing // lowercased model name → pricing (Clone'd)
+ originalCase map[string]string // lowercased model name → original-case model name
+ names []string // priced model names in their ORIGINAL case, insertion-ordered, deduped case-insensitively (first wins)
+}
-// buildPricingLookup 对渠道的定价列表做一次扫描,生成 platform+模型名 的索引。
+// buildPricingIndex 对渠道的定价列表做一次扫描,按 platform 聚合为查找索引。
// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。
-// wildcard 后缀(如 "claude-*")不会被索引(它们不是精确模型名)。
-func buildPricingLookup(pricings []ChannelModelPricing) pricingLookup {
- lookup := make(pricingLookup, len(pricings))
+// 通配符后缀条目(如 "claude-*")不被索引(它们是模式,不是具体模型名)。
+// 同一平台中以大小写不敏感方式去重,先出现者保留原始大小写。
+func buildPricingIndex(pricings []ChannelModelPricing) map[string]*platformPricingIndex {
+ idx := make(map[string]*platformPricingIndex)
for i := range pricings {
p := pricings[i]
- byModel, ok := lookup[p.Platform]
+ pidx, ok := idx[p.Platform]
if !ok {
- byModel = make(map[string]*ChannelModelPricing, len(p.Models))
- lookup[p.Platform] = byModel
+ pidx = &platformPricingIndex{
+ byLower: make(map[string]*ChannelModelPricing),
+ originalCase: make(map[string]string),
+ names: make([]string, 0),
+ }
+ idx[p.Platform] = pidx
}
for _, m := range p.Models {
if _, wild := splitWildcardSuffix(m); wild {
continue
}
lower := strings.ToLower(m)
- if _, exists := byModel[lower]; exists {
- continue // 首个命中胜出(保持 case-insensitive 去重后第一个定价)
+ if _, exists := pidx.byLower[lower]; exists {
+ continue // 首个命中胜出(case-insensitive 去重后第一个定价 / 第一个原始大小写)
}
cp := pricings[i].Clone()
- byModel[lower] = &cp
+ pidx.byLower[lower] = &cp
+ pidx.originalCase[lower] = m
+ pidx.names = append(pidx.names, m)
}
}
- return lookup
-}
-
-// pricedNamesFor 返回指定平台下已索引的精确模型名(保留原始大小写,按添加顺序)。
-// 它是从 pricingLookup 中取 keys 并回查原始 ModelPricing 以得到原样字符串。
-func pricedNamesFor(pricings []ChannelModelPricing, platform string) []string {
- seen := make(map[string]struct{})
- out := make([]string, 0)
- for i := range pricings {
- if pricings[i].Platform != platform {
- continue
- }
- for _, m := range pricings[i].Models {
- if _, wild := splitWildcardSuffix(m); wild {
- continue
- }
- lower := strings.ToLower(m)
- if _, ok := seen[lower]; ok {
- continue
- }
- seen[lower] = struct{}{}
- out = append(out, m)
- }
- }
- return out
+ return idx
}
// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。
@@ -452,16 +443,19 @@ func pricedNamesFor(pricings []ChannelModelPricing, platform string) []string {
// - 遍历 Channel.ModelMapping 的每个 platform 条目;
// - 映射 key 不带尾部 "*":直接作为一个支持模型名(即使没有匹配的定价行,也会产出 Pricing=nil 的条目);
// - 映射 key 带尾部 "*":用同 platform 的 ModelPricing.Models 做前缀匹配展开(定价中带 "*" 的条目被忽略,因为它们本身就是模式,不是具体模型名);
+// - 映射 key 为 `"*"`(单独一个星号)将展开为该平台所有定价模型(前缀为空 → 全匹配)。这是刻意行为,用于"将该平台所有模型透传"的场景;
// - 未在 ModelMapping 中出现的 platform 不会产出任何条目——这是**刻意设计**("没配映射就不显示"),即使该平台有定价行。
//
-// 每个结果尝试从 pricingLookup(平台+模型名索引)查找精确定价,未配置则 Pricing=nil。
+// 当映射 key(exact 或 wildcard 展开后的候选)能命中定价时,结果中的 Name 使用**定价的原始大小写**
+// (定价是模型身份的事实来源),否则保留映射 key 的原始大小写。
+// 每个结果尝试从 platform 索引查找精确定价,未配置则 Pricing=nil。
// 结果按 (Platform, Name) 稳定排序,并按 (Platform, lowercase(Name)) 去重。
func (c *Channel) SupportedModels() []SupportedModel {
if c == nil || len(c.ModelMapping) == 0 {
return nil
}
- lookup := buildPricingLookup(c.ModelPricing)
+ idx := buildPricingIndex(c.ModelPricing)
type dedupKey struct {
platform string
@@ -470,20 +464,23 @@ func (c *Channel) SupportedModels() []SupportedModel {
seen := make(map[dedupKey]struct{})
result := make([]SupportedModel, 0)
- add := func(platform, name string) {
+ add := func(platform, name string, pidx *platformPricingIndex) {
key := dedupKey{platform: platform, name: strings.ToLower(name)}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
var pricing *ChannelModelPricing
- if byModel, ok := lookup[platform]; ok {
- if p, ok := byModel[strings.ToLower(name)]; ok {
+ displayName := name
+ if pidx != nil {
+ lower := strings.ToLower(name)
+ if p, ok := pidx.byLower[lower]; ok {
pricing = p
+ displayName = pidx.originalCase[lower] // 定价大小写胜出
}
}
result = append(result, SupportedModel{
- Name: name,
+ Name: displayName,
Platform: platform,
Pricing: pricing,
})
@@ -493,19 +490,22 @@ func (c *Channel) SupportedModels() []SupportedModel {
if len(mapping) == 0 {
continue
}
- pricedNames := pricedNamesFor(c.ModelPricing, platform)
+ pidx := idx[platform] // 可能为 nil(该平台无定价行)
for src := range mapping {
prefix, isWild := splitWildcardSuffix(src)
if isWild {
+ if pidx == nil {
+ continue
+ }
prefixLower := strings.ToLower(prefix)
- for _, candidate := range pricedNames {
+ for _, candidate := range pidx.names {
if strings.HasPrefix(strings.ToLower(candidate), prefixLower) {
- add(platform, candidate)
+ add(platform, candidate, pidx)
}
}
continue
}
- add(platform, src)
+ add(platform, src, pidx)
}
}
diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go
index 700380c2..8e055518 100644
--- a/backend/internal/service/channel_available.go
+++ b/backend/internal/service/channel_available.go
@@ -65,12 +65,17 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
}
sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
+ billingSource := ch.BillingModelSource
+ if billingSource == "" {
+ billingSource = BillingModelSourceChannelMapped
+ }
+
out = append(out, AvailableChannel{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
- BillingModelSource: ch.BillingModelSource,
+ BillingModelSource: billingSource,
RestrictModels: ch.RestrictModels,
Groups: groups,
SupportedModels: ch.SupportedModels(),
diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go
index 812a3a63..7cb1b272 100644
--- a/backend/internal/service/channel_test.go
+++ b/backend/internal/service/channel_test.go
@@ -637,3 +637,34 @@ func TestSupportedModels_EmptyPlatformMapping(t *testing.T) {
}
require.Empty(t, ch.SupportedModels())
}
+
+func TestSupportedModels_ExactKeyUsesPricedCaseWhenAvailable(t *testing.T) {
+ // mapping key uses uppercase, pricing uses lowercase — pricing's case should win.
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "openai": {"GPT-4o": "gpt-4o"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "gpt-4o", got[0].Name) // pricing's case wins
+}
+
+func TestSupportedModels_AsteriskOnlyMappingExpandsAllPriced(t *testing.T) {
+ // 映射 key 为单独的 "*":前缀为空 → 命中该平台所有定价模型(透传场景)。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"gpt-4o", "gpt-4o-mini"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "openai": {"*": "gpt-4o"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ names := []string{got[0].Name, got[1].Name}
+ require.ElementsMatch(t, []string{"gpt-4o", "gpt-4o-mini"}, names)
+}
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index eb7e91d8..7ad4af28 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -4,8 +4,9 @@
*/
import { apiClient } from '../client'
+import type { BillingMode, ChannelStatus, BillingModelSource } from '@/constants/channel'
-export type BillingMode = 'token' | 'per_request' | 'image'
+export type { BillingMode } from '@/constants/channel'
export interface PricingInterval {
id?: number
@@ -46,8 +47,8 @@ export interface Channel {
id: number
name: string
description: string
- status: string
- billing_model_source: string // "requested" | "upstream"
+ status: ChannelStatus
+ billing_model_source: BillingModelSource
restrict_models: boolean
features_config?: Record
group_ids: number[]
@@ -181,8 +182,8 @@ export interface AvailableChannel {
id: number
name: string
description: string
- status: string
- billing_model_source: string
+ status: ChannelStatus
+ billing_model_source: BillingModelSource
restrict_models: boolean
groups: AvailableGroupRef[]
supported_models: SupportedModel[]
diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue
index 403391a3..13f5d71e 100644
--- a/frontend/src/components/channels/AvailableChannelsTable.vue
+++ b/frontend/src/components/channels/AvailableChannelsTable.vue
@@ -85,18 +85,15 @@ interface Column {
label: string
}
-withDefaults(
- defineProps<{
- columns: Column[]
- rows: Row[]
- loading: boolean
- pricingKeyPrefix: string
- noPricingLabel: string
- noModelsLabel: string
- emptyLabel: string
- }>(),
- { loading: false }
-)
+defineProps<{
+ columns: Column[]
+ rows: Row[]
+ loading: boolean
+ pricingKeyPrefix: string
+ noPricingLabel: string
+ noModelsLabel: string
+ emptyLabel: string
+}>()
const slots = useSlots()
/**
diff --git a/frontend/src/components/channels/PricingRow.vue b/frontend/src/components/channels/PricingRow.vue
index 8db077c0..4134593b 100644
--- a/frontend/src/components/channels/PricingRow.vue
+++ b/frontend/src/components/channels/PricingRow.vue
@@ -7,6 +7,7 @@
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index 4e194a39..ec4aed5d 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -2,7 +2,7 @@ import { mount } from '@vue/test-utils'
import { createPinia, setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue'
-import { useAppStore } from '@/stores'
+import { useAppStore, useAuthStore } from '@/stores'
import type { User } from '@/types'
const routeState = vi.hoisted(() => ({
@@ -15,10 +15,24 @@ const locationState = vi.hoisted(() => ({
let pinia: ReturnType
+const userApiMocks = vi.hoisted(() => ({
+ sendEmailBindingCode: vi.fn(),
+ bindEmailIdentity: vi.fn(),
+}))
+
vi.mock('vue-router', () => ({
useRoute: () => routeState,
}))
+vi.mock('@/api/user', async (importOriginal) => {
+ const actual = await importOriginal()
+ return {
+ ...actual,
+ sendEmailBindingCode: (...args: any[]) => userApiMocks.sendEmailBindingCode(...args),
+ bindEmailIdentity: (...args: any[]) => userApiMocks.bindEmailIdentity(...args),
+ }
+})
+
vi.mock('vue-i18n', async (importOriginal) => {
const actual = await importOriginal()
return {
@@ -34,6 +48,13 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.authBindings.providers.wechat') return 'WeChat'
if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC'
if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim()
+ if (key === 'profile.authBindings.emailPlaceholder') return 'Email address'
+ if (key === 'profile.authBindings.codePlaceholder') return 'Verification code'
+ if (key === 'profile.authBindings.passwordPlaceholder') return 'Set password'
+ if (key === 'profile.authBindings.sendCodeAction') return 'Send code'
+ if (key === 'profile.authBindings.confirmEmailBindAction') return 'Bind email'
+ if (key === 'profile.authBindings.codeSentTo') return `Code sent to ${params?.email || ''}`.trim()
+ if (key === 'profile.authBindings.bindSuccess') return 'Bind success'
return key
},
}),
@@ -76,6 +97,8 @@ describe('ProfileIdentityBindingsSection', () => {
const appStore = useAppStore()
appStore.cachedPublicSettings = null
appStore.publicSettingsLoaded = false
+ userApiMocks.sendEmailBindingCode.mockReset()
+ userApiMocks.bindEmailIdentity.mockReset()
})
afterEach(() => {
@@ -224,4 +247,58 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true)
})
+
+ it('sends email verification code and binds email from the profile card', async () => {
+ userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined)
+ userApiMocks.bindEmailIdentity.mockResolvedValue(
+ createUser({
+ email: 'bound@example.com',
+ email_bound: true,
+ auth_bindings: {
+ email: { bound: true },
+ },
+ })
+ )
+
+ const appStore = useAppStore()
+ const authStore = useAuthStore()
+ authStore.user = createUser({
+ email: 'legacy-user@linuxdo-connect.invalid',
+ email_bound: false,
+ auth_bindings: {
+ email: { bound: false },
+ },
+ })
+ const showSuccessSpy = vi.spyOn(appStore, 'showSuccess')
+
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: authStore.user,
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: false,
+ },
+ })
+
+ await wrapper.get('[data-testid="profile-binding-email-input"]').setValue('bound@example.com')
+ await wrapper.get('[data-testid="profile-binding-email-send-code"]').trigger('click')
+
+ expect(userApiMocks.sendEmailBindingCode).toHaveBeenCalledWith('bound@example.com')
+ expect(showSuccessSpy).toHaveBeenCalledWith('Code sent to bound@example.com')
+
+ await wrapper.get('[data-testid="profile-binding-email-code-input"]').setValue('123456')
+ await wrapper.get('[data-testid="profile-binding-email-password-input"]').setValue('new-password')
+ await wrapper.get('[data-testid="profile-binding-email-submit"]').trigger('click')
+
+ expect(userApiMocks.bindEmailIdentity).toHaveBeenCalledWith({
+ email: 'bound@example.com',
+ verify_code: '123456',
+ password: 'new-password',
+ })
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
+ expect(authStore.user?.email).toBe('bound@example.com')
+ })
})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index ec9c1ea3..2b41a3c3 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -964,6 +964,12 @@ export default {
description: 'View current bindings and connect another provider to this account.',
bindAction: 'Bind {providerName}',
bindSuccess: 'Account linked successfully',
+ emailPlaceholder: 'Enter email address',
+ codePlaceholder: 'Enter verification code',
+ passwordPlaceholder: 'Set a login password',
+ sendCodeAction: 'Send code',
+ confirmEmailBindAction: 'Bind email',
+ codeSentTo: 'Code sent to {email}',
status: {
bound: 'Bound',
notBound: 'Not bound',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 9941d323..b60a69d6 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -968,6 +968,12 @@ export default {
description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。',
bindAction: '绑定 {providerName}',
bindSuccess: '账号绑定成功',
+ emailPlaceholder: '输入邮箱地址',
+ codePlaceholder: '输入验证码',
+ passwordPlaceholder: '设置登录密码',
+ sendCodeAction: '发送验证码',
+ confirmEmailBindAction: '绑定邮箱',
+ codeSentTo: '验证码已发送到 {email}',
status: {
bound: '已绑定',
notBound: '未绑定',
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 07341919..bfc11cb2 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -118,6 +118,8 @@ export interface RegisterRequest {
export interface SendVerifyCodeRequest {
email: string
turnstile_token?: string
+ pending_auth_token?: string
+ pending_oauth_token?: string
}
export interface SendVerifyCodeResponse {
diff --git a/frontend/src/views/auth/EmailVerifyView.vue b/frontend/src/views/auth/EmailVerifyView.vue
index 84dd4667..d7bf6b7a 100644
--- a/frontend/src/views/auth/EmailVerifyView.vue
+++ b/frontend/src/views/auth/EmailVerifyView.vue
@@ -176,7 +176,12 @@ import { AuthLayout } from '@/components/layout'
import Icon from '@/components/icons/Icon.vue'
import TurnstileWidget from '@/components/TurnstileWidget.vue'
import { useAuthStore, useAppStore } from '@/stores'
-import { persistOAuthTokenContext, getPublicSettings, sendVerifyCode } from '@/api/auth'
+import {
+ persistOAuthTokenContext,
+ getPublicSettings,
+ sendPendingOAuthVerifyCode,
+ sendVerifyCode,
+} from '@/api/auth'
import { apiClient } from '@/api/client'
import { buildAuthErrorMessage } from '@/utils/authError'
import {
@@ -355,18 +360,21 @@ async function sendCode(): Promise {
errorMessage.value = ''
try {
- if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
+ if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value)
return
}
- const response = await sendVerifyCode({
+ const requestPayload = {
email: email.value,
[pendingAuthTokenField.value]: pendingAuthToken.value || undefined,
// 优先使用重发时新获取的 token(因为初始 token 可能已被使用)
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
- } as Parameters[0])
+ } as Parameters[0]
+ const response = pendingAuthToken.value
+ ? await sendPendingOAuthVerifyCode(requestPayload)
+ : await sendVerifyCode(requestPayload)
codeSent.value = true
startCountdown(response.countdown)
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 6c923b0a..735c6582 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -444,6 +444,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
}
+function isCreateAccountRecoveryError(error: unknown): boolean {
+ const data = (error as {
+ response?: {
+ data?: {
+ reason?: string
+ error?: string
+ code?: string
+ step?: string
+ intent?: string
+ }
+ }
+ }).response?.data
+ const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
+ .map(value => value?.trim().toLowerCase())
+ .filter((value): value is string => Boolean(value))
+
+ return states.includes('email_exists') ||
+ states.includes('bind_login_required') ||
+ states.includes('bind_login') ||
+ states.includes('adopt_existing_user_by_email')
+}
+
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -540,10 +562,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email,
password: payload.password,
verify_code: payload.verifyCode || undefined,
+ invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
} catch (e: unknown) {
+ if (isCreateAccountRecoveryError(e)) {
+ switchToBindLoginMode(payload.email)
+ return
+ }
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally {
isSubmitting.value = false
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 840e4964..019cab54 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -488,6 +488,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
}
+function isCreateAccountRecoveryError(error: unknown): boolean {
+ const data = (error as {
+ response?: {
+ data?: {
+ reason?: string
+ error?: string
+ code?: string
+ step?: string
+ intent?: string
+ }
+ }
+ }).response?.data
+ const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
+ .map(value => value?.trim().toLowerCase())
+ .filter((value): value is string => Boolean(value))
+
+ return states.includes('email_exists') ||
+ states.includes('bind_login_required') ||
+ states.includes('bind_login') ||
+ states.includes('adopt_existing_user_by_email')
+}
+
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -584,10 +606,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email,
password: payload.password,
verify_code: payload.verifyCode || undefined,
+ invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
} catch (e: unknown) {
+ if (isCreateAccountRecoveryError(e)) {
+ switchToBindLoginMode(payload.email)
+ return
+ }
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally {
isSubmitting.value = false
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 35cd0032..36e3140c 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -647,6 +647,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string {
return err.response?.data?.detail || err.response?.data?.message || err.message || fallback
}
+function isCreateAccountRecoveryError(error: unknown): boolean {
+ const data = (error as {
+ response?: {
+ data?: {
+ reason?: string
+ error?: string
+ code?: string
+ step?: string
+ intent?: string
+ }
+ }
+ }).response?.data
+ const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent]
+ .map(value => value?.trim().toLowerCase())
+ .filter((value): value is string => Boolean(value))
+
+ return states.includes('email_exists') ||
+ states.includes('bind_login_required') ||
+ states.includes('bind_login') ||
+ states.includes('adopt_existing_user_by_email')
+}
+
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
@@ -739,10 +761,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) {
email: payload.email,
password: payload.password,
verify_code: payload.verifyCode || undefined,
+ invitation_code: payload.invitationCode || undefined,
...serializeAdoptionDecision(currentAdoptionDecision())
})
await finalizePendingAccountResponse(data)
} catch (e: unknown) {
+ if (isCreateAccountRecoveryError(e)) {
+ switchToBindLoginMode(payload.email)
+ return
+ }
accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed'))
} finally {
isSubmitting.value = false
diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
index f6dff076..c231d6e7 100644
--- a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
+++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
@@ -11,6 +11,7 @@ const {
clearPendingAuthSessionMock,
getPublicSettingsMock,
sendVerifyCodeMock,
+ sendPendingOAuthVerifyCodeMock,
persistOAuthTokenContextMock,
apiClientPostMock,
authStoreState,
@@ -23,6 +24,7 @@ const {
clearPendingAuthSessionMock: vi.fn(),
getPublicSettingsMock: vi.fn(),
sendVerifyCodeMock: vi.fn(),
+ sendPendingOAuthVerifyCodeMock: vi.fn(),
persistOAuthTokenContextMock: vi.fn(),
apiClientPostMock: vi.fn(),
authStoreState: {
@@ -80,6 +82,7 @@ vi.mock('@/api/auth', async () => {
...actual,
getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args),
persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args),
}
})
@@ -100,6 +103,7 @@ describe('EmailVerifyView', () => {
clearPendingAuthSessionMock.mockReset()
getPublicSettingsMock.mockReset()
sendVerifyCodeMock.mockReset()
+ sendPendingOAuthVerifyCodeMock.mockReset()
persistOAuthTokenContextMock.mockReset()
apiClientPostMock.mockReset()
authStoreState.pendingAuthSession = null
@@ -112,9 +116,86 @@ describe('EmailVerifyView', () => {
registration_email_suffix_whitelist: [],
})
sendVerifyCodeMock.mockResolvedValue({ countdown: 60 })
+ sendPendingOAuthVerifyCodeMock.mockResolvedValue({ countdown: 60 })
setTokenMock.mockResolvedValue({})
})
+ it('uses the pending oauth verify-code endpoint when register data carries a pending auth session', async () => {
+ authStoreState.pendingAuthSession = {
+ token: 'pending-token-1',
+ token_field: 'pending_auth_token',
+ provider: 'wechat',
+ redirect: '/profile',
+ }
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'fresh@example.com',
+ pending_auth_token: 'pending-token-1',
+ })
+ expect(sendVerifyCodeMock).not.toHaveBeenCalled()
+ })
+
+ it('skips the registration email suffix whitelist for pending oauth verification', async () => {
+ authStoreState.pendingAuthSession = {
+ token: 'pending-token-2',
+ token_field: 'pending_auth_token',
+ provider: 'oidc',
+ redirect: '/profile',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'fresh@example.com',
+ pending_auth_token: 'pending-token-2',
+ })
+ expect(showErrorMock).not.toHaveBeenCalled()
+ })
+
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
authStoreState.pendingAuthSession = {
token: 'pending-token-1',
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index a04915b7..f612681a 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -15,6 +15,7 @@ const getPublicSettings = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -61,7 +62,8 @@ vi.mock('@/api/auth', async () => {
completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args),
login2FA: (...args: any[]) => login2FA(...args),
- sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args)
}
})
@@ -79,6 +81,7 @@ describe('LinuxDoCallbackView', () => {
login2FA.mockReset()
apiClientPost.mockReset()
sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({
turnstile_enabled: false,
turnstile_site_key: ''
@@ -334,6 +337,11 @@ describe('LinuxDoCallbackView', () => {
})
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
+ getPublicSettings.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -370,6 +378,7 @@ describe('LinuxDoCallbackView', () => {
await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
await flushPromises()
@@ -377,6 +386,7 @@ describe('LinuxDoCallbackView', () => {
email: 'new@example.com',
password: 'secret-123',
verify_code: '246810',
+ invitation_code: 'INVITE123',
adopt_display_name: true,
adopt_avatar: false
})
@@ -384,12 +394,48 @@ describe('LinuxDoCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ apiClientPost.mockRejectedValue({
+ response: {
+ data: {
+ reason: 'EMAIL_EXISTS',
+ message: 'email already exists'
+ }
+ }
+ })
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('existing@example.com')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe(
+ 'existing@example.com'
+ )
+ })
+
it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome'
})
- sendVerifyCode.mockResolvedValue({
+ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent',
countdown: 60
})
@@ -411,7 +457,7 @@ describe('LinuxDoCallbackView', () => {
await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCode).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'new@example.com'
})
})
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index 259fb282..0edcb931 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -15,6 +15,7 @@ const getPublicSettings = vi.fn()
const login2FA = vi.fn()
const apiClientPost = vi.fn()
const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
vi.mock('vue-router', () => ({
useRoute: () => ({
@@ -66,7 +67,8 @@ vi.mock('@/api/auth', async () => {
completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args),
getPublicSettings: (...args: any[]) => getPublicSettings(...args),
login2FA: (...args: any[]) => login2FA(...args),
- sendVerifyCode: (...args: any[]) => sendVerifyCode(...args)
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args)
}
})
@@ -84,6 +86,7 @@ describe('OidcCallbackView', () => {
login2FA.mockReset()
apiClientPost.mockReset()
sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
getPublicSettings.mockResolvedValue({
oidc_oauth_provider_name: 'ExampleID',
turnstile_enabled: false,
@@ -312,6 +315,12 @@ describe('OidcCallbackView', () => {
})
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
+ getPublicSettings.mockResolvedValue({
+ oidc_oauth_provider_name: 'ExampleID',
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -348,6 +357,7 @@ describe('OidcCallbackView', () => {
await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="oidc-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
await flushPromises()
@@ -355,6 +365,7 @@ describe('OidcCallbackView', () => {
email: 'new@example.com',
password: 'secret-123',
verify_code: '246810',
+ invitation_code: 'INVITE123',
adopt_display_name: true,
adopt_avatar: false
})
@@ -362,12 +373,48 @@ describe('OidcCallbackView', () => {
expect(replace).toHaveBeenCalledWith('/welcome')
})
+ it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
+ exchangePendingOAuthCompletion.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome'
+ })
+ apiClientPost.mockRejectedValue({
+ response: {
+ data: {
+ reason: 'EMAIL_EXISTS',
+ message: 'email already exists'
+ }
+ }
+ })
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('existing@example.com')
+ await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe(
+ 'existing@example.com'
+ )
+ })
+
it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletion.mockResolvedValue({
error: 'email_required',
redirect: '/welcome'
})
- sendVerifyCode.mockResolvedValue({
+ sendPendingOAuthVerifyCode.mockResolvedValue({
message: 'sent',
countdown: 60
})
@@ -389,7 +436,7 @@ describe('OidcCallbackView', () => {
await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCode).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
email: 'new@example.com'
})
})
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index fed88890..e02060f6 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -8,6 +8,8 @@ const {
login2FAMock,
apiClientPostMock,
sendVerifyCodeMock,
+ sendPendingOAuthVerifyCodeMock,
+ getPublicSettingsMock,
prepareOAuthBindAccessTokenCookieMock,
getAuthTokenMock,
replaceMock,
@@ -24,6 +26,8 @@ const {
login2FAMock: vi.fn(),
apiClientPostMock: vi.fn(),
sendVerifyCodeMock: vi.fn(),
+ sendPendingOAuthVerifyCodeMock: vi.fn(),
+ getPublicSettingsMock: vi.fn(),
prepareOAuthBindAccessTokenCookieMock: vi.fn(),
getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(),
@@ -130,6 +134,8 @@ vi.mock('@/api/auth', async () => {
completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args),
login2FA: (...args: any[]) => login2FAMock(...args),
sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args),
prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args),
getAuthToken: (...args: any[]) => getAuthTokenMock(...args),
}
@@ -142,6 +148,8 @@ describe('WechatCallbackView', () => {
login2FAMock.mockReset()
apiClientPostMock.mockReset()
sendVerifyCodeMock.mockReset()
+ sendPendingOAuthVerifyCodeMock.mockReset()
+ getPublicSettingsMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
showSuccessMock.mockReset()
@@ -167,6 +175,11 @@ describe('WechatCallbackView', () => {
configurable: true,
value: 'Mozilla/5.0',
})
+ getPublicSettingsMock.mockResolvedValue({
+ invitation_code_enabled: false,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ })
})
it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => {
@@ -478,6 +491,11 @@ describe('WechatCallbackView', () => {
})
it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => {
+ getPublicSettingsMock.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ })
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
@@ -514,6 +532,7 @@ describe('WechatCallbackView', () => {
await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ')
await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="wechat-create-account-invitation-code"]').setValue(' INVITE123 ')
await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
await flushPromises()
@@ -521,6 +540,7 @@ describe('WechatCallbackView', () => {
email: 'new@example.com',
password: 'secret-123',
verify_code: '246810',
+ invitation_code: 'INVITE123',
adopt_display_name: true,
adopt_avatar: false,
})
@@ -528,12 +548,48 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome')
})
+ it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome',
+ })
+ apiClientPostMock.mockRejectedValue({
+ response: {
+ data: {
+ reason: 'EMAIL_EXISTS',
+ message: 'email already exists',
+ },
+ },
+ })
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('existing@example.com')
+ await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click')
+ await flushPromises()
+
+ expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe(
+ 'existing@example.com'
+ )
+ })
+
it('sends a verify code for pending oauth account creation', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
redirect: '/welcome',
})
- sendVerifyCodeMock.mockResolvedValue({
+ sendPendingOAuthVerifyCodeMock.mockResolvedValue({
message: 'sent',
countdown: 60,
})
@@ -555,7 +611,7 @@ describe('WechatCallbackView', () => {
await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click')
await flushPromises()
- expect(sendVerifyCodeMock).toHaveBeenCalledWith({
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
email: 'new@example.com',
})
})
--
GitLab
From 7e89bca5e660b69de795a8a634fa17f5c24bea14 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 10:41:29 +0800
Subject: [PATCH 125/261] fix: tighten pending oauth email routing and binding
state
---
.../handler/auth_oauth_pending_flow.go | 163 ++++++++++++++++--
.../handler/auth_oauth_pending_flow_test.go | 148 ++++++++++++++--
.../internal/service/auth_oauth_email_flow.go | 122 ++++++++++++-
backend/internal/service/user_service.go | 29 +++-
frontend/src/api/auth.ts | 10 +-
.../ProfileIdentityBindingsSection.vue | 7 +-
.../ProfileIdentityBindingsSection.spec.ts | 23 +++
frontend/src/views/auth/EmailVerifyView.vue | 78 ++++++++-
.../auth/__tests__/EmailVerifyView.spec.ts | 159 +++++++++++++++++
9 files changed, 685 insertions(+), 54 deletions(-)
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 1d3b113f..8a3006f3 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"strings"
+ "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
@@ -35,6 +36,8 @@ const (
oauthCompletionResponseKey = "completion_response"
)
+var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
+
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
@@ -481,6 +484,26 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
return
}
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, oauthAdoptionDecisionRequest{})
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ } else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
+ response.ErrorFrom(c, err)
+ return
+ }
+
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
if err != nil {
response.ErrorFrom(c, err)
@@ -946,11 +969,46 @@ func applyPendingOAuthBinding(
return nil
}
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func applyPendingOAuthBindingTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+ forceBind bool,
+ applyFirstBindDefaults bool,
+) error {
+ if tx == nil || session == nil {
+ return nil
+ }
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+ return nil
+ }
+
targetUserID := int64(0)
if overrideUserID != nil && *overrideUserID > 0 {
targetUserID = *overrideUserID
} else {
- resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
if err != nil {
return err
}
@@ -974,22 +1032,15 @@ func applyPendingOAuthBinding(
}
}
- tx, err := client.Tx(ctx)
- if err != nil {
- return err
- }
- defer func() { _ = tx.Rollback() }()
- txCtx := dbent.NewTxContext(ctx, tx)
-
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
- Exec(txCtx); err != nil {
+ Exec(ctx); err != nil {
return err
}
}
- identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
+ identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
if err != nil {
return err
}
@@ -1009,31 +1060,71 @@ func applyPendingOAuthBinding(
if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
}
- if _, err := updateIdentity.Save(txCtx); err != nil {
+ if _, err := updateIdentity.Save(ctx); err != nil {
return err
}
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
- Save(txCtx); err != nil {
+ Save(ctx); err != nil {
return err
}
}
if applyFirstBindDefaults && authService != nil {
- if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
+ if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
return err
}
}
if shouldAdoptAvatar && userService != nil {
- if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
+ if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
}
- return tx.Commit()
+ return nil
+}
+
+func consumePendingOAuthBrowserSessionTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ session *dbent.PendingAuthSession,
+) error {
+ if tx == nil || session == nil {
+ return service.ErrPendingAuthSessionNotFound
+ }
+
+ storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return service.ErrPendingAuthSessionNotFound
+ }
+ return err
+ }
+
+ now := time.Now().UTC()
+ if storedSession.ConsumedAt != nil {
+ return service.ErrPendingAuthSessionConsumed
+ }
+ if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
+ return service.ErrPendingAuthSessionExpired
+ }
+ if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
+ strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return service.ErrPendingAuthBrowserMismatch
+ }
+
+ if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
+ SetConsumedAt(now).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt().
+ Save(ctx); err != nil {
+ return err
+ }
+
+ return nil
}
func applyPendingOAuthAdoption(
@@ -1256,7 +1347,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
- pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -1341,7 +1432,20 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
- if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+
+ tx, err := client.Tx(c.Request.Context())
+ if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ defer func() { _ = tx.Rollback() }()
+ txCtx := dbent.NewTxContext(c.Request.Context(), tx)
+
+ if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+ _ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1350,11 +1454,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
if err := h.authService.FinalizeOAuthEmailAccount(
- c.Request.Context(),
+ txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
); err != nil {
+ _ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1362,7 +1467,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
- if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+ _ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1371,6 +1477,25 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
+ if pendingOAuthCreateAccountPreCommitHook != nil {
+ if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
index 1013a082..008c9da2 100644
--- a/backend/internal/handler/auth_oauth_pending_flow_test.go
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -903,6 +903,63 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
+// TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState verifies that when a
+// pending OAuth session submits a verify-code request for an email that already belongs
+// to an existing user, the handler does not send a code: it returns a "pending_session"
+// payload with step "bind_login_required" and rewrites the stored session into the
+// adopt_existing_user_by_email intent, targeting the existing user.
+func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ // Pre-create the user that owns the email the pending session will try to claim.
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Seed a pending OIDC session that is still in the "email_required" step.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-send-code-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-send-code-123").
+ SetBrowserSessionKey("existing-email-send-code-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "email_required",
+ },
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Both cookies must match the stored session for the handler to accept the request.
+ body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.SendPendingOAuthVerifyCode(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // The response redirects the client back into the pending-session flow instead of
+ // sending a verification code.
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, "bind_login_required", payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+
+ // The stored session must now target the existing user for adoption.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
emailVerifyEnabled: true,
@@ -1032,6 +1089,78 @@ func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T
require.Nil(t, storedSession.ConsumedAt)
}
+// TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit verifies the
+// transactional guarantee of the create-account path: if a failure is injected after the
+// identity binding but before commit (via pendingOAuthCreateAccountPreCommitHook), the
+// handler returns 500 and NOTHING persists — no user row, no auth identity, and the
+// pending session remains unconsumed.
+func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ // Simulate the legacy compensation path being unavailable: deleting the created
+ // user is rejected while its auth identity still exists, so only a real
+ // transaction rollback can clean up.
+ userRepoOptions: oauthPendingFlowUserRepoOptions{
+ rejectDeleteWhileAuthIdentityExists: true,
+ },
+ })
+ ctx := context.Background()
+
+ // Seed a pending OIDC session ready for account creation.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-finalize-failure-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-finalize-failure-123").
+ SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Force a failure at the last moment before commit; reset the package-level hook
+ // on cleanup so other tests are unaffected.
+ pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
+ return errors.New("forced post-bind failure")
+ }
+ t.Cleanup(func() {
+ pendingOAuthCreateAccountPreCommitHook = nil
+ })
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+
+ // The rollback must have removed the newly created user...
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ // ...and the bound identity row.
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ // The pending session must remain usable (not consumed) so the user can retry.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -1618,7 +1747,6 @@ type oauthPendingFlowTestHandlerOptions struct {
defaultSubAssigner service.DefaultSubscriptionAssigner
totpCache service.TotpCache
totpEncryptor service.SecretEncryptor
- redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
userRepoOptions oauthPendingFlowUserRepoOptions
}
@@ -1685,13 +1813,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
client: client,
options: options.userRepoOptions,
}
- redeemRepo := service.RedeemCodeRepository(nil)
- if options.redeemRepoFactory != nil {
- redeemRepo = options.redeemRepoFactory(client)
- }
- if redeemRepo == nil {
- redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
- }
+ redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
var emailService *service.EmailService
if options.emailCache != nil {
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
@@ -2011,14 +2133,6 @@ func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Contex
panic("unexpected SumPositiveBalanceByUser call")
}
-type oauthPendingFlowFailingUseRedeemRepo struct {
- *oauthPendingFlowRedeemCodeRepo
-}
-
-func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
- return errors.New("forced invitation use failure")
-}
-
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
t.Helper()
@@ -2093,7 +2207,7 @@ func countProviderGrantRecords(
}
type oauthPendingFlowUserRepo struct {
- client *dbent.Client
+ client *dbent.Client
options oauthPendingFlowUserRepoOptions
}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index ce25222c..ea558ae2 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -7,6 +7,9 @@ import (
"net/mail"
"strings"
"time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
)
func normalizeOAuthSignupSource(signupSource string) string {
@@ -50,7 +53,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil, nil
}
- if s.redeemRepo == nil {
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
return nil, ErrServiceUnavailable
}
@@ -59,7 +62,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
return nil, ErrInvitationCodeRequired
}
- redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
return nil, ErrInvitationCodeInvalid
}
@@ -181,12 +184,12 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
return err
}
if invitationRedeemCode != nil {
- if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+ if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return ErrInvitationCodeInvalid
}
}
- s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
return nil
@@ -211,7 +214,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil
}
- if s.redeemRepo == nil {
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
return ErrServiceUnavailable
}
@@ -220,7 +223,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
return nil
}
- redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrRedeemCodeNotFound) {
return nil
@@ -234,12 +237,115 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
redeemCode.Status = StatusUnused
redeemCode.UsedBy = nil
redeemCode.UsedAt = nil
- if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
+ if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
return fmt.Errorf("restore invitation code: %w", err)
}
return nil
}
+// oauthEmailFlowClient returns the ent client the OAuth email flow should use for
+// database access: the transaction-scoped client when ctx carries an active ent
+// transaction (so invitation/signup writes join the caller's transaction and roll
+// back with it), otherwise the service-level client. Returns nil when the service
+// has no ent client configured, which callers use to fall back to the repo layer.
+func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
+ if s == nil || s.entClient == nil {
+ return nil
+ }
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return tx.Client()
+ }
+ return s.entClient
+}
+
+// loadOAuthRegistrationInvitation looks up an invitation redeem code by its code string.
+// When an ent client is available (possibly transaction-scoped via ctx) it queries ent
+// directly and maps the entity to the service-level RedeemCode; a missing row is
+// normalized to ErrRedeemCodeNotFound. Otherwise it falls back to the legacy
+// redeemRepo. Callers guard against both being nil before invoking this.
+func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ // Map the ent entity onto the service-level DTO field by field.
+ return &RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: oauthEmailFlowStringValue(entity.Notes),
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+ }
+ // No ent client configured: use the legacy repository path.
+ return s.redeemRepo.GetByCode(ctx, invitationCode)
+}
+
+// useOAuthRegistrationInvitation marks an invitation code as consumed by userID.
+// On the ent path the status predicate in the Where clause makes this a
+// compare-and-set: only a row still in StatusUnused is updated, so a concurrent
+// claim of the same code yields affected == 0 and ErrRedeemCodeUsed instead of a
+// double-spend. Without an ent client it delegates to the legacy redeemRepo.Use.
+func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ affected, err := client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
+ SetStatus(StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ // Zero rows updated means the code was no longer unused (already claimed).
+ if affected == 0 {
+ return ErrRedeemCodeUsed
+ }
+ return nil
+ }
+ return s.redeemRepo.Use(ctx, invitationID, userID)
+}
+
+// updateOAuthRegistrationInvitation writes the full state of a RedeemCode back to
+// storage (used by the restore path to un-consume a code after a failed signup).
+// On the ent path each nilable field (UsedBy, UsedAt, GroupID) is explicitly Set
+// or Clear'd so that restoring a code to unused actually nulls the usage columns.
+// A nil code is a no-op. Falls back to redeemRepo.Update when no ent client exists.
+func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ update := client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ // Mirror pointer semantics: nil means "clear the column", non-nil means "set it".
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+ }
+ return s.redeemRepo.Update(ctx, code)
+}
+
+// updateOAuthSignupSource best-effort-records the signup source on the user row.
+// It silently no-ops when there is no ent client, the user ID is invalid, or the
+// source is blank, and deliberately discards the write error — signup-source
+// attribution is informational and must not fail account finalization.
+// NOTE(review): the swallowed error appears intentional (this replaced the
+// fire-and-forget postAuthUserBootstrap call) — confirm no caller expects it.
+func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
+ client := s.oauthEmailFlowClient(ctx)
+ if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
+}
+
+// oauthEmailFlowStringValue dereferences an optional string, mapping nil to "".
+func oauthEmailFlowStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
+
// ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound.
@@ -269,7 +375,7 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
if s != nil && s.userRepo != nil && userID > 0 {
user, err := s.userRepo.GetByID(ctx, userID)
- if err == nil {
+ if err == nil && user != nil && !isReservedEmail(user.Email) {
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
}
}
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 9c7d4747..e6053984 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -240,7 +240,7 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
}
return UserIdentitySummarySet{
- Email: s.buildEmailIdentitySummary(user),
+ Email: s.buildEmailIdentitySummary(user, records),
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
OIDC: s.buildProviderIdentitySummary("oidc", records),
WeChat: s.buildProviderIdentitySummary("wechat", records),
@@ -497,7 +497,7 @@ func compressInlineAvatar(decoded []byte) ([]byte, string, error) {
return nil, "", ErrAvatarTooLarge
}
-func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary {
+func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
summary := UserIdentitySummary{
Provider: "email",
CanBind: false,
@@ -508,11 +508,34 @@ func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary
return summary
}
+ filtered := filterUserAuthIdentities(records, "email")
+ if len(filtered) > 0 {
+ primary := selectPrimaryUserAuthIdentity(filtered)
+ email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email"))
+ if email == "" {
+ email = strings.TrimSpace(primary.ProviderSubject)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(user.Email)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(primary.ProviderKey)
+ }
+
+ summary.Bound = true
+ summary.BoundCount = len(filtered)
+ summary.DisplayName = email
+ summary.SubjectHint = maskEmailIdentity(email)
+ summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
+ summary.VerifiedAt = primary.VerifiedAt
+ return summary
+ }
+
+ // Compatibility fallback for legacy normal-email users that predate auth_identities backfill.
email := strings.TrimSpace(user.Email)
if email == "" || isReservedEmail(email) {
return summary
}
-
summary.Bound = true
summary.BoundCount = 1
summary.DisplayName = email
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 0f768018..89964c3c 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -208,6 +208,12 @@ export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse
export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {}
+export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse {
+ auth_result?: string
+ provider?: string
+ redirect?: string
+}
+
export type OAuthCompletionKind = 'login' | 'bind'
export interface OAuthAdoptionDecision {
@@ -451,8 +457,8 @@ export async function sendVerifyCode(
export async function sendPendingOAuthVerifyCode(
request: SendVerifyCodeRequest
-): Promise {
- const { data } = await apiClient.post(
+): Promise {
+ const { data } = await apiClient.post(
'/auth/oauth/pending/send-verify-code',
request
)
diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
index ccc1cbd0..653b4e33 100644
--- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
+++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
@@ -209,7 +209,12 @@ function getBindingStatus(provider: UserAuthProvider): boolean {
function getBindingStatusForUser(user: User | null | undefined, provider: UserAuthProvider): boolean {
if (provider === 'email') {
- return typeof user?.email_bound === 'boolean' ? user.email_bound : Boolean(user?.email)
+ if (typeof user?.email_bound === 'boolean') {
+ return user.email_bound
+ }
+ const nested = user?.auth_bindings?.email ?? user?.identity_bindings?.email
+ const normalized = normalizeBindingStatus(nested)
+ return normalized ?? false
}
const directFlag = user?.[`${provider}_bound` as keyof User]
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index ec4aed5d..c07acf18 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -301,4 +301,27 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
expect(authStore.user?.email).toBe('bound@example.com')
})
+
+ it('keeps the email binding form visible when the user still lacks an email identity', () => {
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: createUser({
+ email: 'legacy@example.com',
+ email_bound: false,
+ auth_bindings: {
+ email: { bound: false },
+ },
+ }),
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: false,
+ },
+ })
+
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound')
+ expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
+ })
})
diff --git a/frontend/src/views/auth/EmailVerifyView.vue b/frontend/src/views/auth/EmailVerifyView.vue
index d7bf6b7a..01829765 100644
--- a/frontend/src/views/auth/EmailVerifyView.vue
+++ b/frontend/src/views/auth/EmailVerifyView.vue
@@ -179,6 +179,8 @@ import { useAuthStore, useAppStore } from '@/stores'
import {
persistOAuthTokenContext,
getPublicSettings,
+ isOAuthLoginCompletion,
+ type PendingOAuthSendVerifyCodeResponse,
sendPendingOAuthVerifyCode,
sendVerifyCode,
} from '@/api/auth'
@@ -216,10 +218,13 @@ type PendingAuthSessionSummary = {
redirect?: string
}
type PendingOAuthCreateAccountResponse = {
+ auth_result?: string
access_token: string
refresh_token?: string
expires_in?: number
token_type?: string
+ provider?: string
+ redirect?: string
}
const email = ref('')
@@ -353,6 +358,46 @@ function onTurnstileError(): void {
errors.value.turnstile = t('auth.turnstileFailed')
}
+// True when the view was entered from a pending third-party OAuth flow,
+// detected by a non-blank provider in the pending auth session state.
+function isPendingOAuthFlow(): boolean {
+ return Boolean(pendingProvider.value.trim())
+}
+
+// The registration email-suffix whitelist only applies to plain email signups;
+// pending OAuth flows (provider or token present) bypass it.
+function shouldBypassRegistrationEmailPolicy(): boolean {
+ return isPendingOAuthFlow() || Boolean(pendingAuthToken.value.trim())
+}
+
+// Maps a pending OAuth provider name (case/whitespace-insensitive) to its
+// frontend callback route; unknown providers fall back to the generic callback.
+function resolvePendingOAuthCallbackRoute(provider: string): string {
+ switch (provider.trim().toLowerCase()) {
+ case 'linuxdo':
+ return '/auth/linuxdo/callback'
+ case 'oidc':
+ return '/auth/oidc/callback'
+ case 'wechat':
+ return '/auth/wechat/callback'
+ default:
+ return '/auth/callback'
+ }
+}
+}
+
+// True when create-account responded with a pending-session continuation
+// (auth_result === 'pending_session') instead of a completed token pair.
+function isPendingOAuthSessionResponse(data: PendingOAuthCreateAccountResponse): boolean {
+ return data.auth_result === 'pending_session'
+}
+
+// Returns the send-verify-code response itself when it signals a pending-session
+// continuation, or null when it is a normal "code sent" response.
+function getPendingOAuthSendCodeSessionResponse(
+ data: PendingOAuthSendVerifyCodeResponse,
+): PendingOAuthSendVerifyCodeResponse | null {
+ return data.auth_result === 'pending_session' ? data : null
+}
+
+// Stores the pending OAuth session in the auth store before navigating back to
+// the provider callback route. Falls back to the currently-known provider and
+// redirect when the server response omitted them.
+function persistPendingOAuthSession(provider: string, redirect?: string): void {
+ authStore.setPendingAuthSession({
+ token: pendingAuthToken.value,
+ token_field: pendingAuthTokenField.value,
+ provider: provider.trim() || pendingProvider.value.trim(),
+ redirect: redirect || pendingRedirect.value || undefined,
+ })
+}
+
// ==================== Send Code ====================
async function sendCode(): Promise {
@@ -360,7 +405,7 @@ async function sendCode(): Promise {
errorMessage.value = ''
try {
- if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
+ if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value)
return
@@ -372,10 +417,25 @@ async function sendCode(): Promise {
// 优先使用重发时新获取的 token(因为初始 token 可能已被使用)
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
} as Parameters[0]
- const response = pendingAuthToken.value
+ const response = isPendingOAuthFlow()
? await sendPendingOAuthVerifyCode(requestPayload)
: await sendVerifyCode(requestPayload)
+ const pendingSendCodeSession = isPendingOAuthFlow()
+ ? getPendingOAuthSendCodeSessionResponse(response as PendingOAuthSendVerifyCodeResponse)
+ : null
+ if (pendingSendCodeSession) {
+ sessionStorage.removeItem('register_data')
+ persistPendingOAuthSession(
+ pendingSendCodeSession.provider || pendingProvider.value,
+ pendingSendCodeSession.redirect,
+ )
+ await router.push(
+ resolvePendingOAuthCallbackRoute(pendingSendCodeSession.provider || pendingProvider.value),
+ )
+ return
+ }
+
codeSent.value = true
startCountdown(response.countdown)
@@ -438,13 +498,13 @@ async function handleVerify(): Promise {
isLoading.value = true
try {
- if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
+ if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
errorMessage.value = buildEmailSuffixNotAllowedMessage()
appStore.showError(errorMessage.value)
return
}
- if (pendingProvider.value) {
+ if (isPendingOAuthFlow()) {
const { data } = await apiClient.post(
'/auth/oauth/pending/create-account',
{
@@ -456,6 +516,16 @@ async function handleVerify(): Promise {
adopt_avatar: pendingAdoptionDecision.value?.adoptAvatar
}
)
+ if (isPendingOAuthSessionResponse(data)) {
+ sessionStorage.removeItem('register_data')
+ persistPendingOAuthSession(data.provider || pendingProvider.value, data.redirect)
+ await router.push(resolvePendingOAuthCallbackRoute(data.provider || pendingProvider.value))
+ return
+ }
+ if (!isOAuthLoginCompletion(data)) {
+ throw new Error(t('auth.verifyFailed'))
+ }
+
persistOAuthTokenContext(data)
await authStore.setToken(data.access_token)
authStore.clearPendingAuthSession?.()
diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
index c231d6e7..9f67a994 100644
--- a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
+++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts
@@ -8,6 +8,7 @@ const {
showErrorMock,
registerMock,
setTokenMock,
+ setPendingAuthSessionMock,
clearPendingAuthSessionMock,
getPublicSettingsMock,
sendVerifyCodeMock,
@@ -21,6 +22,7 @@ const {
showErrorMock: vi.fn(),
registerMock: vi.fn(),
setTokenMock: vi.fn(),
+ setPendingAuthSessionMock: vi.fn(),
clearPendingAuthSessionMock: vi.fn(),
getPublicSettingsMock: vi.fn(),
sendVerifyCodeMock: vi.fn(),
@@ -68,6 +70,7 @@ vi.mock('@/stores', () => ({
pendingAuthSession: authStoreState.pendingAuthSession,
register: (...args: any[]) => registerMock(...args),
setToken: (...args: any[]) => setTokenMock(...args),
+ setPendingAuthSession: (...args: any[]) => setPendingAuthSessionMock(...args),
clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args),
}),
useAppStore: () => ({
@@ -100,6 +103,7 @@ describe('EmailVerifyView', () => {
showErrorMock.mockReset()
registerMock.mockReset()
setTokenMock.mockReset()
+ setPendingAuthSessionMock.mockReset()
clearPendingAuthSessionMock.mockReset()
getPublicSettingsMock.mockReset()
sendVerifyCodeMock.mockReset()
@@ -196,6 +200,97 @@ describe('EmailVerifyView', () => {
expect(showErrorMock).not.toHaveBeenCalled()
})
+ it('uses the pending oauth verify-code endpoint when auth store only carries the pending provider', async () => {
+ authStoreState.pendingAuthSession = {
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
+ email: 'fresh@example.com',
+ pending_oauth_token: undefined,
+ })
+ expect(sendVerifyCodeMock).not.toHaveBeenCalled()
+ expect(showErrorMock).not.toHaveBeenCalled()
+ })
+
+ it('returns to the oauth callback flow when pending send-code detects an existing account email', async () => {
+ authStoreState.pendingAuthSession = {
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sendPendingOAuthVerifyCodeMock.mockResolvedValue({
+ auth_result: 'pending_session',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+
+ mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ })
+ expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback')
+ expect(showErrorMock).not.toHaveBeenCalled()
+ })
+
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
authStoreState.pendingAuthSession = {
token: 'pending-token-1',
@@ -252,6 +347,70 @@ describe('EmailVerifyView', () => {
expect(registerMock).not.toHaveBeenCalled()
})
+ it('returns to the oauth callback flow when pending account creation becomes bind-login', async () => {
+ authStoreState.pendingAuthSession = {
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ }
+ getPublicSettingsMock.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ registration_email_suffix_whitelist: ['allowed.com'],
+ })
+ sessionStorage.setItem(
+ 'register_data',
+ JSON.stringify({
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ })
+ )
+ apiClientPostMock.mockResolvedValue({
+ data: {
+ auth_result: 'pending_session',
+ provider: 'oidc',
+ step: 'bind_login_required',
+ redirect: '/profile/security',
+ email: 'fresh@example.com',
+ },
+ })
+
+ const wrapper = mount(EmailVerifyView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ TurnstileWidget: true,
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+ await wrapper.get('#code').setValue('123456')
+ await wrapper.get('form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
+ email: 'fresh@example.com',
+ password: 'secret-123',
+ verify_code: '123456',
+ })
+ expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'oidc',
+ redirect: '/profile/security',
+ })
+ expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback')
+ expect(setTokenMock).not.toHaveBeenCalled()
+ expect(persistOAuthTokenContextMock).not.toHaveBeenCalled()
+ expect(clearPendingAuthSessionMock).not.toHaveBeenCalled()
+ expect(showSuccessMock).not.toHaveBeenCalled()
+ })
+
it('keeps the normal email registration flow unchanged', async () => {
sessionStorage.setItem(
'register_data',
--
GitLab
From f3986501663c864594f1444440f5e36199595983 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:00:08 +0800
Subject: [PATCH 126/261] fix: harden oidc compat email and email bind tx
---
backend/internal/handler/auth_oidc_oauth.go | 69 ++++++-
.../internal/handler/auth_oidc_oauth_test.go | 121 +++++++++++++
.../internal/service/auth_email_binding.go | 169 ++++++++++++++++++
.../service/auth_service_email_bind_test.go | 71 ++++++++
4 files changed, 424 insertions(+), 6 deletions(-)
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 5901a953..6d19e9d6 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -19,6 +19,7 @@ import (
"strings"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -323,18 +324,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if emailVerified == nil {
emailVerified = idClaims.EmailVerified
}
- if cfg.RequireEmailVerified {
- if emailVerified == nil || !*emailVerified {
- redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
- return
- }
- }
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
return
}
identityKey := oidcIdentityKey(issuer, subject)
+ compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
@@ -357,6 +353,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
if err != nil {
@@ -416,6 +415,40 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
+ compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if compatEmailUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: "adopt_existing_user_by_email",
+ Identity: identityRef,
+ TargetUserID: &compatEmailUser.ID,
+ ResolvedEmail: compatEmailUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ "step": "bind_login_required",
+ "email": compatEmailUser.Email,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if cfg.RequireEmailVerified {
+ if emailVerified == nil || !*emailVerified {
+ redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
+ return
+ }
+ }
+
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@@ -473,6 +506,30 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectToFrontendCallback(c, frontendCallback)
}
+func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index ba736db2..5cd8e0ea 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -245,6 +245,127 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T
require.Nil(t, completion["error"])
}
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-compat",
+ PreferredUsername: "oidc_compat",
+ DisplayName: "OIDC Compat Display",
+ AvatarURL: "https://cdn.example/oidc-compat.png",
+ Email: "legacy@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "bind_login_required", completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-unverified-compat",
+ PreferredUsername: "oidc_unverified",
+ DisplayName: "OIDC Unverified Compat Display",
+ AvatarURL: "https://cdn.example/oidc-unverified.png",
+ Email: "owner@example.com",
+ EmailVerified: false,
+ })
+ defer cleanup()
+ cfg.RequireEmailVerified = true
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "adopt_existing_user_by_email", session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "owner@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Equal(t, "bind_login_required", completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+}
+
func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-invite",
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
index b999660b..58f8e647 100644
--- a/backend/internal/service/auth_email_binding.go
+++ b/backend/internal/service/auth_email_binding.go
@@ -6,7 +6,10 @@ import (
"fmt"
"net/mail"
"strings"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
@@ -55,6 +58,13 @@ func (s *AuthService) BindEmailIdentity(
}
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ if firstRealEmailBind && s.entClient != nil {
+ if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
+ return nil, err
+ }
+ return currentUser, nil
+ }
+
currentUser.Email = normalizedEmail
currentUser.PasswordHash = hashedPassword
if err := s.userRepo.Update(ctx, currentUser); err != nil {
@@ -126,3 +136,162 @@ func hasBindableEmailIdentitySubject(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return normalized != "" && !isReservedEmail(normalized)
}
+
+func (s *AuthService) bindEmailIdentityWithDefaultsTx(
+ ctx context.Context,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+) error {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
+ return err
+ }
+ if err := tx.Commit(); err != nil {
+ return ErrServiceUnavailable
+ }
+ return nil
+}
+
+func (s *AuthService) bindEmailIdentityWithDefaults(
+ ctx context.Context,
+ client *dbent.Client,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+) error {
+ if client == nil || currentUser == nil || currentUser.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ oldEmail := currentUser.Email
+ if _, err := client.User.UpdateOneID(currentUser.ID).
+ SetEmail(email).
+ SetPasswordHash(hashedPassword).
+ Save(ctx); err != nil {
+ if dbent.IsConstraintError(err) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+ return fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+
+ updatedUser, err := client.User.Get(ctx, currentUser.ID)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ currentUser.Email = updatedUser.Email
+ currentUser.PasswordHash = updatedUser.PasswordHash
+ currentUser.Balance = updatedUser.Balance
+ currentUser.Concurrency = updatedUser.Concurrency
+ currentUser.UpdatedAt = updatedUser.UpdatedAt
+ return nil
+}
+
+func replaceBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ oldEmail string,
+ newEmail string,
+ source string,
+) error {
+ newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
+ if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func ensureBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ subject string,
+ source string,
+) error {
+ if client == nil || userID <= 0 || subject == "" {
+ return nil
+ }
+
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_email_bind"
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrEmailExists
+ }
+ return nil
+}
+
+func normalizeBoundEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || isReservedEmail(normalized) {
+ return ""
+ }
+ return normalized
+}
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
index 899a736d..fd5f499b 100644
--- a/backend/internal/service/auth_service_email_bind_test.go
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -5,6 +5,7 @@ package service_test
import (
"context"
"database/sql"
+ "errors"
"testing"
"time"
@@ -34,6 +35,20 @@ func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+type flakyEmailBindDefaultSubAssignerStub struct {
+ err error
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return nil, false, s.err
+}
+
func newAuthServiceForEmailBind(
t *testing.T,
settings map[string]string,
@@ -187,6 +202,62 @@ func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testi
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
}
+func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
+ assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
+ user, err := client.User.Create().
+ SetEmail(originalEmail).
+ SetUsername("legacy-rollback").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
+ require.ErrorContains(t, err, "apply email first bind defaults")
+ require.ErrorContains(t, err, "temporary assign failure")
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, originalEmail, storedUser.Email)
+ require.Equal(t, "old-hash", storedUser.PasswordHash)
+ require.Equal(t, 2.5, storedUser.Balance)
+ require.Equal(t, 1, storedUser.Concurrency)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("rollback@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
--
GitLab
From 33b208ab6faba4144300b1a9728d673063c95ae5 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:00:18 +0800
Subject: [PATCH 127/261] fix: restore legacy oauth callback fragment
compatibility
---
.../src/views/auth/LinuxDoCallbackView.vue | 91 +++++++++++++++----
frontend/src/views/auth/OidcCallbackView.vue | 91 +++++++++++++++----
.../__tests__/LinuxDoCallbackView.spec.ts | 68 ++++++++++++++
.../auth/__tests__/OidcCallbackView.spec.ts | 68 ++++++++++++++
4 files changed, 282 insertions(+), 36 deletions(-)
diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue
index 735c6582..a775075d 100644
--- a/frontend/src/views/auth/LinuxDoCallbackView.vue
+++ b/frontend/src/views/auth/LinuxDoCallbackView.vue
@@ -249,6 +249,7 @@ import {
login2FA,
persistOAuthTokenContext,
type OAuthAdoptionDecision,
+ type OAuthTokenResponse,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -278,6 +279,7 @@ const pendingAccountAction = ref<'none' | 'create_account' | 'bind_login'>('none
const pendingAccountEmail = ref('')
const bindLoginEmail = ref('')
const bindLoginPassword = ref('')
+const legacyPendingOAuthToken = ref('')
const accountActionError = ref('')
const canReturnToCreateAccount = ref(false)
const bindSuccessMessage = t('profile.authBindings.bindSuccess')
@@ -315,6 +317,30 @@ function parseFragmentParams(): URLSearchParams {
return new URLSearchParams(hash)
}
+function readLegacyFragmentLogin(params: URLSearchParams): OAuthTokenResponse | null {
+ const accessToken = params.get('access_token')?.trim() || ''
+ if (!accessToken) {
+ return null
+ }
+
+ const completion: OAuthTokenResponse = {
+ access_token: accessToken
+ }
+ const refreshToken = params.get('refresh_token')?.trim() || ''
+ if (refreshToken) {
+ completion.refresh_token = refreshToken
+ }
+ const expiresIn = Number.parseInt(params.get('expires_in')?.trim() || '', 10)
+ if (Number.isFinite(expiresIn) && expiresIn > 0) {
+ completion.expires_in = expiresIn
+ }
+ const tokenType = params.get('token_type')?.trim() || ''
+ if (tokenType) {
+ completion.token_type = tokenType
+ }
+ return completion
+}
+
function sanitizeRedirectPath(path: string | null | undefined): string {
if (!path) return '/dashboard'
if (!path.startsWith('/')) return '/dashboard'
@@ -521,10 +547,18 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
- const tokenData = await completeLinuxDoOAuthRegistration(
- invitationCode.value.trim(),
- currentAdoptionDecision()
- )
+ const tokenData = legacyPendingOAuthToken.value
+ ? (
+ await apiClient.post('/auth/oauth/linuxdo/complete-registration', {
+ pending_oauth_token: legacyPendingOAuthToken.value,
+ invitation_code: invitationCode.value.trim(),
+ ...serializeAdoptionDecision(currentAdoptionDecision())
+ })
+ ).data
+ : await completeLinuxDoOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
@@ -621,51 +655,72 @@ async function handleSubmitTotpChallenge() {
onMounted(async () => {
const params = parseFragmentParams()
+ const legacyLogin = readLegacyFragmentLogin(params)
+ const legacyPendingToken = params.get('pending_oauth_token')?.trim() || ''
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
-
- if (error) {
- errorMessage.value = errorDesc || error
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
+ const redirect = sanitizeRedirectPath(
+ params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
+ )
try {
+ if (legacyLogin) {
+ persistOAuthTokenContext(legacyLogin)
+ await authStore.setToken(legacyLogin.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+ return
+ }
+
+ if (error === 'invitation_required' && legacyPendingToken) {
+ legacyPendingOAuthToken.value = legacyPendingToken
+ redirectTo.value = redirect
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ if (error) {
+ errorMessage.value = errorDesc || error
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
+
const completion = await exchangePendingOAuthCompletion()
- const redirect = sanitizeRedirectPath(
+ const completionRedirect = sanitizeRedirectPath(
completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
)
applyAdoptionSuggestionState(completion)
- redirectTo.value = redirect
+ redirectTo.value = completionRedirect
if (completion.error === 'invitation_required') {
needsInvitation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (applyTotpChallenge(completion as LinuxDoPendingActionResponse)) {
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
applyPendingAccountAction(completion as LinuxDoPendingActionResponse)
if (pendingAccountAction.value !== 'none') {
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (adoptionRequired.value && hasSuggestedProfile(completion)) {
needsAdoptionConfirmation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
- await finalizeCompletion(completion, redirect)
+ await finalizeCompletion(completion, completionRedirect)
} catch (e: unknown) {
clearPendingAuthSession()
errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed'))
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 019cab54..e15e752f 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -259,6 +259,7 @@ import {
login2FA,
persistOAuthTokenContext,
type OAuthAdoptionDecision,
+ type OAuthTokenResponse,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -287,6 +288,7 @@ const pendingAccountAction = ref<'none' | 'create_account' | 'bind_login'>('none
const pendingAccountEmail = ref('')
const bindLoginEmail = ref('')
const bindLoginPassword = ref('')
+const legacyPendingOAuthToken = ref('')
const accountActionError = ref('')
const canReturnToCreateAccount = ref(false)
const bindSuccessMessage = t('profile.authBindings.bindSuccess')
@@ -331,6 +333,30 @@ function parseFragmentParams(): URLSearchParams {
return new URLSearchParams(hash)
}
+function readLegacyFragmentLogin(params: URLSearchParams): OAuthTokenResponse | null {
+ const accessToken = params.get('access_token')?.trim() || ''
+ if (!accessToken) {
+ return null
+ }
+
+ const completion: OAuthTokenResponse = {
+ access_token: accessToken
+ }
+ const refreshToken = params.get('refresh_token')?.trim() || ''
+ if (refreshToken) {
+ completion.refresh_token = refreshToken
+ }
+ const expiresIn = Number.parseInt(params.get('expires_in')?.trim() || '', 10)
+ if (Number.isFinite(expiresIn) && expiresIn > 0) {
+ completion.expires_in = expiresIn
+ }
+ const tokenType = params.get('token_type')?.trim() || ''
+ if (tokenType) {
+ completion.token_type = tokenType
+ }
+ return completion
+}
+
function sanitizeRedirectPath(path: string | null | undefined): string {
if (!path) return '/dashboard'
if (!path.startsWith('/')) return '/dashboard'
@@ -565,10 +591,18 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
- const tokenData = await completeOIDCOAuthRegistration(
- invitationCode.value.trim(),
- currentAdoptionDecision()
- )
+ const tokenData = legacyPendingOAuthToken.value
+ ? (
+ await apiClient.post('/auth/oauth/oidc/complete-registration', {
+ pending_oauth_token: legacyPendingOAuthToken.value,
+ invitation_code: invitationCode.value.trim(),
+ ...serializeAdoptionDecision(currentAdoptionDecision())
+ })
+ ).data
+ : await completeOIDCOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
@@ -667,51 +701,72 @@ onMounted(async () => {
void loadProviderName()
const params = parseFragmentParams()
+ const legacyLogin = readLegacyFragmentLogin(params)
+ const legacyPendingToken = params.get('pending_oauth_token')?.trim() || ''
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
-
- if (error) {
- errorMessage.value = errorDesc || error
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
+ const redirect = sanitizeRedirectPath(
+ params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
+ )
try {
+ if (legacyLogin) {
+ persistOAuthTokenContext(legacyLogin)
+ await authStore.setToken(legacyLogin.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+ return
+ }
+
+ if (error === 'invitation_required' && legacyPendingToken) {
+ legacyPendingOAuthToken.value = legacyPendingToken
+ redirectTo.value = redirect
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ if (error) {
+ errorMessage.value = errorDesc || error
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
+
const completion = await exchangePendingOAuthCompletion() as PendingOidcCompletion
- const redirect = sanitizeRedirectPath(
+ const completionRedirect = sanitizeRedirectPath(
completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
)
applyAdoptionSuggestionState(completion)
- redirectTo.value = redirect
+ redirectTo.value = completionRedirect
if (completion.error === 'invitation_required') {
needsInvitation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (applyTotpChallenge(completion)) {
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
applyPendingAccountAction(completion)
if (pendingAccountAction.value !== 'none') {
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
if (adoptionRequired.value && hasSuggestedProfile(completion)) {
needsAdoptionConfirmation.value = true
isProcessing.value = false
- persistPendingAuthSession(redirect)
+ persistPendingAuthSession(completionRedirect)
return
}
- await finalizeCompletion(completion, redirect)
+ await finalizeCompletion(completion, completionRedirect)
} catch (e: unknown) {
clearPendingAuthSession()
errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed'))
diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
index f612681a..0daf5d9a 100644
--- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts
@@ -86,6 +86,74 @@ describe('LinuxDoCallbackView', () => {
turnstile_enabled: false,
turnstile_site_key: ''
})
+ window.location.hash = ''
+ localStorage.clear()
+ })
+
+ it('accepts the legacy fragment token success callback without pending-session exchange', async () => {
+ window.location.hash =
+ '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard'
+ setToken.mockResolvedValue({})
+
+ mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '<div><slot /></div>' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token')
+ expect(localStorage.getItem('token_expires_at')).not.toBeNull()
+ expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess')
+ expect(replace).toHaveBeenCalledWith('/legacy-dashboard')
+ })
+
+ it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => {
+ window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite'
+ apiClientPost.mockResolvedValue({
+ data: {
+ access_token: 'legacy-access-token',
+ refresh_token: 'legacy-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ }
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(LinuxDoCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '<div><slot /></div>' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+ await flushPromises()
+
+ expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ adopt_display_name: true,
+ adopt_avatar: true,
+ pending_oauth_token: 'legacy-pending-token',
+ invitation_code: 'invite-code'
+ })
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(replace).toHaveBeenCalledWith('/legacy-invite')
})
it('does not send adoption decisions during the initial exchange', async () => {
diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
index 0edcb931..18128f17 100644
--- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts
@@ -92,6 +92,74 @@ describe('OidcCallbackView', () => {
turnstile_enabled: false,
turnstile_site_key: ''
})
+ window.location.hash = ''
+ localStorage.clear()
+ })
+
+ it('accepts the legacy fragment token success callback without pending-session exchange', async () => {
+ window.location.hash =
+ '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard'
+ setToken.mockResolvedValue({})
+
+ mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token')
+ expect(localStorage.getItem('token_expires_at')).not.toBeNull()
+ expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess')
+ expect(replace).toHaveBeenCalledWith('/legacy-dashboard')
+ })
+
+ it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => {
+ window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite'
+ apiClientPost.mockResolvedValue({
+ data: {
+ access_token: 'legacy-access-token',
+ refresh_token: 'legacy-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer'
+ }
+ })
+ setToken.mockResolvedValue({})
+
+ const wrapper = mount(OidcCallbackView, {
+ global: {
+ stubs: {
+ AuthLayout: { template: '
' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled()
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+ await flushPromises()
+
+ expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ adopt_display_name: true,
+ adopt_avatar: true,
+ pending_oauth_token: 'legacy-pending-token',
+ invitation_code: 'invite-code'
+ })
+ expect(setToken).toHaveBeenCalledWith('legacy-access-token')
+ expect(replace).toHaveBeenCalledWith('/legacy-invite')
})
it('does not send adoption decisions during the initial exchange', async () => {
--
GitLab
From 375aefa2097f3c4ce4f7497365fd76306d409d45 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 11:31:54 +0800
Subject: [PATCH 128/261] refactor(channels): centralize BillingModelSource
normalization and exhaustive enum maps
- service: add normalizeBillingModelSource helper, apply in Create/GetByID/Update/List/ListAvailable outputs
- handler: drop channelToResponse fallback now that service owns the default; add passthrough test
- frontend: replace ternary status/billing-source lookups with Record maps so new union members fail the build
- chip/table: drop local type aliases, reuse UserSupportedModel/UserPricingInterval directly
- tests: assert short-circuit on ListAll error, wrap-prefix preservation, and Name-based default lookup
---
.../internal/handler/admin/channel_handler.go | 3 --
.../handler/admin/channel_handler_test.go | 18 ++++++-
backend/internal/service/channel_available.go | 12 ++---
.../service/channel_available_test.go | 25 +++++++---
backend/internal/service/channel_service.go | 45 ++++++++++++++---
.../channels/AvailableChannelsTable.vue | 9 ++--
.../channels/SupportedModelChip.vue | 14 ++----
.../src/views/admin/AvailableChannelsView.vue | 49 ++++++++++++++-----
8 files changed, 122 insertions(+), 53 deletions(-)
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index 9151d018..950e6e72 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -158,9 +158,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
- if resp.BillingModelSource == "" {
- resp.BillingModelSource = service.BillingModelSourceChannelMapped
- }
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go
index f218cce4..12cd4bdd 100644
--- a/backend/internal/handler/admin/channel_handler_test.go
+++ b/backend/internal/handler/admin/channel_handler_test.go
@@ -91,7 +91,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
ch := &service.Channel{
ID: 1,
Name: "ch",
- BillingModelSource: "",
+ BillingModelSource: service.BillingModelSourceChannelMapped,
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
@@ -105,6 +105,9 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
},
}
+ // handler 层 channelToResponse 现在是纯透传:BillingModelSource 的空值兜底
+ // 已下放到 service 层(Create/GetByID/List/Update/ListAvailable 出口统一处理),
+ // 因此这里构造 fixture 时直接传入归一化后的值。
resp := channelToResponse(ch)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
@@ -117,6 +120,19 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
+func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) {
+ // handler 不再兜底 BillingModelSource:空值应原样透传(由 service 层负责默认回填)。
+ ch := &service.Channel{
+ ID: 1,
+ Name: "ch",
+ BillingModelSource: "",
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ }
+ resp := channelToResponse(ch)
+ require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责")
+}
+
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go
index 62406cd0..a162d81d 100644
--- a/backend/internal/service/channel_available.go
+++ b/backend/internal/service/channel_available.go
@@ -32,6 +32,9 @@ type AvailableChannel struct {
// 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。
// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
// 的分组(已停用或删除)会被忽略。
+//
+// 前置条件:s.groupRepo 必须非 nil(由 wire DI 保证)。直接 nil-deref 用于 fail-fast,
+// 避免静默掩盖注入缺失。
func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
channels, err := s.repo.ListAll(ctx)
if err != nil {
@@ -61,19 +64,16 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
groups = append(groups, ref)
}
}
- sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
+ sort.SliceStable(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
- billingSource := ch.BillingModelSource
- if billingSource == "" {
- billingSource = BillingModelSourceChannelMapped
- }
+ normalizeBillingModelSource(ch)
out = append(out, AvailableChannel{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
- BillingModelSource: billingSource,
+ BillingModelSource: ch.BillingModelSource,
RestrictModels: ch.RestrictModels,
Groups: groups,
SupportedModels: ch.SupportedModels(),
diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go
index 5da5e6e1..86bb4bb6 100644
--- a/backend/internal/service/channel_available_test.go
+++ b/backend/internal/service/channel_available_test.go
@@ -14,12 +14,15 @@ import (
// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
// listActiveErr 非 nil 时,ListActive 返回该错误用于错误传播测试。
+// listActiveCalls 记录调用次数,用于断言「失败短路时不再访问 groupRepo」等行为。
type stubGroupRepoForAvailable struct {
- activeGroups []Group
- listActiveErr error
+ activeGroups []Group
+ listActiveErr error
+ listActiveCalls int
}
func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
+ s.listActiveCalls++
if s.listActiveErr != nil {
return nil, s.listActiveErr
}
@@ -125,15 +128,18 @@ func TestListAvailable_SortedByName(t *testing.T) {
}
func TestListAvailable_ListAllErrorPropagates(t *testing.T) {
- // ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,不再访问 groupRepo。
+ // ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo(短路)。
sentinel := errors.New("list-all-boom")
repo := &mockChannelRepository{
listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel },
}
- svc := NewChannelService(repo, &stubGroupRepoForAvailable{}, nil)
+ groupRepo := &stubGroupRepoForAvailable{}
+ svc := NewChannelService(repo, groupRepo, nil)
out, err := svc.ListAvailable(context.Background())
require.Nil(t, out)
require.ErrorIs(t, err, sentinel)
+ require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v")
+ require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive")
}
func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
@@ -146,6 +152,7 @@ func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
out, err := svc.ListAvailable(context.Background())
require.Nil(t, out)
require.ErrorIs(t, err, sentinel)
+ require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v")
}
func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
@@ -159,6 +166,12 @@ func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
out, err := svc.ListAvailable(context.Background())
require.NoError(t, err)
require.Len(t, out, 2)
- require.Equal(t, BillingModelSourceChannelMapped, out[0].BillingModelSource)
- require.Equal(t, BillingModelSourceUpstream, out[1].BillingModelSource)
+
+ // 按 Name 查找,避免依赖排序副作用。
+ byName := make(map[string]string, len(out))
+ for _, ch := range out {
+ byName[ch.Name] = ch.BillingModelSource
+ }
+ require.Equal(t, BillingModelSourceChannelMapped, byName["empty"])
+ require.Equal(t, BillingModelSourceUpstream, byName["explicit"])
}
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index 250df07b..4f22e205 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -686,9 +686,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
AccountStatsPricingRules: input.AccountStatsPricingRules,
}
- if channel.BillingModelSource == "" {
- channel.BillingModelSource = BillingModelSourceChannelMapped
- }
+ normalizeBillingModelSource(channel)
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
@@ -704,12 +702,31 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
s.invalidateCache()
- return s.repo.GetByID(ctx, channel.ID)
+ created, err := s.repo.GetByID(ctx, channel.ID)
+ if err != nil {
+ return nil, err
+ }
+ normalizeBillingModelSource(created)
+ return created, nil
}
-// GetByID 获取渠道详情
+// GetByID 获取渠道详情。返回前统一把空 BillingModelSource 回填为 ChannelMapped,
+// 让所有 handler 无需重复处理历史空值。
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
- return s.repo.GetByID(ctx, id)
+ ch, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ normalizeBillingModelSource(ch)
+ return ch, nil
+}
+
+// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
+// 统一在 service 层完成,避免 handler 响应层重复兜底。
+func normalizeBillingModelSource(ch *Channel) {
+ if ch != nil && ch.BillingModelSource == "" {
+ ch.BillingModelSource = BillingModelSourceChannelMapped
+ }
}
// Update 更新渠道
@@ -741,7 +758,12 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
s.invalidateCache()
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
- return s.repo.GetByID(ctx, id)
+ updated, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ normalizeBillingModelSource(updated)
+ return updated, nil
}
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
@@ -859,7 +881,14 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// List 获取渠道列表
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
- return s.repo.List(ctx, params, status, search)
+ channels, res, err := s.repo.List(ctx, params, status, search)
+ if err != nil {
+ return nil, nil, err
+ }
+ for i := range channels {
+ normalizeBillingModelSource(&channels[i])
+ }
+ return channels, res, nil
}
// modelEntry 表示一个模型模式条目(用于冲突检测)
diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue
index e9011ec5..0bd19518 100644
--- a/frontend/src/components/channels/AvailableChannelsTable.vue
+++ b/frontend/src/components/channels/AvailableChannelsTable.vue
@@ -61,7 +61,7 @@ import { computed, useSlots } from 'vue'
import DataTable from '@/components/common/DataTable.vue'
import Icon from '@/components/icons/Icon.vue'
import SupportedModelChip from './SupportedModelChip.vue'
-import type { UserSupportedModelPricing } from '@/api/channels'
+import type { UserSupportedModel } from '@/api/channels'
interface GroupRef {
id: number
@@ -73,11 +73,8 @@ interface Row {
name: string
description?: string
groups: GroupRef[]
- supported_models: Array<{
- name: string
- platform: string
- pricing: UserSupportedModelPricing | null
- }>
+ // 复用 user 侧最小 DTO;admin 侧 SupportedModel 结构上是其超集,可直接传入。
+ supported_models: UserSupportedModel[]
[key: string]: unknown
}
diff --git a/frontend/src/components/channels/SupportedModelChip.vue b/frontend/src/components/channels/SupportedModelChip.vue
index f3e5549b..600e3ef5 100644
--- a/frontend/src/components/channels/SupportedModelChip.vue
+++ b/frontend/src/components/channels/SupportedModelChip.vue
@@ -127,19 +127,13 @@ import {
BILLING_MODE_IMAGE,
type BillingMode
} from '@/constants/channel'
+// 复用 api/channels.ts 的用户侧最小形态 DTO。
+// admin 侧 ChannelModelPricing 字段更多,但结构上是用户 DTO 的超集,admin 视图传入可直接通过结构化子类型检查。
import type { UserPricingInterval, UserSupportedModel } from '@/api/channels'
-/**
- * 复用 api/channels.ts 的用户侧最小形态 DTO。
- * admin 侧 ChannelModelPricing 字段更多,但结构上是用户 DTO 的超集,
- * 因此 admin 视图传入时 TypeScript 结构化子类型会直接通过。
- */
-type PricingInterval = UserPricingInterval
-type SupportedModelLike = UserSupportedModel
-
const props = withDefaults(
defineProps<{
- model: SupportedModelLike
+ model: UserSupportedModel
/** i18n 前缀:管理端传 `admin.availableChannels.pricing`,用户端传 `availableChannels.pricing`。 */
pricingKeyPrefix?: string
noPricingLabel?: string
@@ -180,7 +174,7 @@ function formatRange(min: number, max: number | null): string {
return `(${min}, ${maxLabel}]`
}
-function formatInterval(iv: PricingInterval, mode: BillingMode): string {
+function formatInterval(iv: UserPricingInterval, mode: BillingMode): string {
if (mode === BILLING_MODE_PER_REQUEST || mode === BILLING_MODE_IMAGE) {
return formatScaled(iv.per_request_price, 1)
}
diff --git a/frontend/src/views/admin/AvailableChannelsView.vue b/frontend/src/views/admin/AvailableChannelsView.vue
index c7c27154..74e85618 100644
--- a/frontend/src/views/admin/AvailableChannelsView.vue
+++ b/frontend/src/views/admin/AvailableChannelsView.vue
@@ -46,20 +46,16 @@
- {{ statusLabel(row.status) }}
+ {{ statusStyles[row.status as ChannelStatus].label }}
- {{ t(`admin.availableChannels.billingSource.${row.billing_model_source}`) }}
+ {{ billingSourceLabels[row.billing_model_source as BillingModelSource] }}
@@ -78,7 +74,15 @@ import AvailableChannelsTable from '@/components/channels/AvailableChannelsTable
import channelsAPI, { type AvailableChannel } from '@/api/admin/channels'
import { useAppStore } from '@/stores/app'
import { extractApiErrorMessage } from '@/utils/apiError'
-import { CHANNEL_STATUS_ACTIVE, type ChannelStatus } from '@/constants/channel'
+import {
+ CHANNEL_STATUS_ACTIVE,
+ CHANNEL_STATUS_DISABLED,
+ BILLING_MODEL_SOURCE_REQUESTED,
+ BILLING_MODEL_SOURCE_UPSTREAM,
+ BILLING_MODEL_SOURCE_CHANNEL_MAPPED,
+ type ChannelStatus,
+ type BillingModelSource
+} from '@/constants/channel'
const { t } = useI18n()
const appStore = useAppStore()
@@ -95,11 +99,30 @@ const columns = computed(() => [
{ key: 'supported_models', label: t('admin.availableChannels.columns.supportedModels') }
])
-function statusLabel(status: ChannelStatus): string {
- return status === CHANNEL_STATUS_ACTIVE
- ? t('admin.availableChannels.statusActive')
- : t('admin.availableChannels.statusDisabled')
-}
+/**
+ * 显示样式:i18n label + Tailwind class,按 ChannelStatus 完整穷举。
+ * 用 Record 强制未来新增状态时 TS 编译失败,避免遗漏分支。
+ */
+const statusStyles = computed>(() => ({
+ [CHANNEL_STATUS_ACTIVE]: {
+ label: t('admin.availableChannels.statusActive'),
+ cls: 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-400'
+ },
+ [CHANNEL_STATUS_DISABLED]: {
+ label: t('admin.availableChannels.statusDisabled'),
+ cls: 'bg-gray-100 text-gray-600 dark:bg-dark-700 dark:text-gray-400'
+ }
+}))
+
+/**
+ * BillingModelSource 显式映射:避免将后端 snake_case 字面量直接拼成 i18n key,
+ * 同时在 BillingModelSource 扩展时 TS 编译失败以暴露遗漏。
+ */
+const billingSourceLabels = computed>(() => ({
+ [BILLING_MODEL_SOURCE_REQUESTED]: t('admin.availableChannels.billingSource.requested'),
+ [BILLING_MODEL_SOURCE_UPSTREAM]: t('admin.availableChannels.billingSource.upstream'),
+ [BILLING_MODEL_SOURCE_CHANNEL_MAPPED]: t('admin.availableChannels.billingSource.channel_mapped')
+}))
const filteredChannels = computed(() => {
const q = searchQuery.value.trim().toLowerCase()
--
GitLab
From 9742796ee727e678fabf1e4028640a764cbac895 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:41:02 +0800
Subject: [PATCH 129/261] fix: retire public payment verify and backfill trade
no
---
backend/internal/handler/payment_handler.go | 34 +--
.../handler/payment_handler_resume_test.go | 53 ++++
backend/internal/server/routes/payment.go | 5 +-
.../service/payment_order_lifecycle.go | 14 +-
.../service/payment_order_lifecycle_test.go | 251 ++++++++++++++++++
frontend/src/api/__tests__/payment.spec.ts | 36 +++
frontend/src/api/payment.ts | 5 -
.../user/__tests__/PaymentResultView.spec.ts | 8 -
8 files changed, 369 insertions(+), 37 deletions(-)
create mode 100644 backend/internal/service/payment_order_lifecycle_test.go
create mode 100644 frontend/src/api/__tests__/payment.spec.ts
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 5fd6b43e..273aea73 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -2,6 +2,7 @@ package handler
import (
"fmt"
+ "net/http"
"strconv"
"strings"
@@ -459,29 +460,20 @@ type PublicOrderResult struct {
Status string `json:"status"`
}
-// VerifyOrderPublic verifies payment status without requiring authentication.
-// Returns limited order info (no user details) to prevent information leakage.
+var errPaymentPublicOrderVerifyRemoved = infraerrors.New(
+ http.StatusGone,
+ "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED",
+ "public payment order verification by out_trade_no has been removed; use resume_token recovery instead",
+).WithMetadata(map[string]string{
+ "replacement_endpoint": "/api/v1/payment/public/orders/resolve",
+ "replacement_field": "resume_token",
+})
+
+// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
+// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
- var req VerifyOrderRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
- order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, PublicOrderResult{
- ID: order.ID,
- OutTradeNo: order.OutTradeNo,
- Amount: order.Amount,
- PayAmount: order.PayAmount,
- PaymentType: order.PaymentType,
- OrderType: order.OrderType,
- Status: order.Status,
- })
+ response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved)
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
index 323f7292..28da15d9 100644
--- a/backend/internal/handler/payment_handler_resume_test.go
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -3,10 +3,24 @@
package handler
import (
+ "bytes"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
@@ -59,3 +73,42 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
}
}
+
+func TestVerifyOrderPublicReturnsGone(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/verify",
+ bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.VerifyOrderPublic(ctx)
+
+ require.Equal(t, http.StatusGone, recorder.Code)
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusGone, resp.Code)
+ require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason)
+ require.Contains(t, resp.Message, "removed")
+}
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index dff14a70..ec340d94 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -44,8 +44,9 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
- // Payment result page needs to verify order status without login
- // (user session may have expired during provider redirect).
+ // Signed resume-token recovery is the supported public lookup path.
+ // The legacy anonymous out_trade_no verify endpoint is kept only as a
+ // compatibility shim that returns HTTP 410 Gone.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 50a56ad0..1564c36d 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -151,7 +151,19 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
- if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ notificationTradeNo := o.PaymentTradeNo
+ if upstreamTradeNo := resp.TradeNo; upstreamTradeNo != "" && upstreamTradeNo != notificationTradeNo {
+ if _, updateErr := s.entClient.PaymentOrder.Update().
+ Where(paymentorder.IDEQ(o.ID)).
+ SetPaymentTradeNo(upstreamTradeNo).
+ Save(ctx); updateErr != nil {
+ slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr)
+ } else {
+ o.PaymentTradeNo = upstreamTradeNo
+ }
+ notificationTradeNo = upstreamTradeNo
+ }
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
new file mode 100644
index 00000000..3d4773a4
--- /dev/null
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -0,0 +1,251 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type paymentOrderLifecycleQueryProvider struct {
+ lastQueryTradeNo string
+ resp *payment.QueryOrderResponse
+}
+
+type paymentOrderLifecycleRedeemRepo struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Name() string {
+ return "payment-order-lifecycle-query-provider"
+}
+
+func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ p.lastQueryTradeNo = tradeNo
+ return p.resp, nil
+}
+
+func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) {
+ for _, code := range r.codesByCode {
+ if code.ID != id {
+ continue
+ }
+ cloned := *code
+ return &cloned, nil
+ }
+ return nil, ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ redeemCode, ok := r.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range r.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ r.codesByCode[code] = redeemCode
+ r.useCalls = append(r.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected call")
+}
+
+func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_trade_no_missing").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-123",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, OrderStatusCompleted, got.Status)
+ require.Equal(t, "upstream-trade-123", got.PaymentTradeNo)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusCompleted, reloaded.Status)
+ require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo)
+
+ require.Equal(t, 88.0, userRepo.getByIDUser.Balance)
+ require.Len(t, redeemRepo.useCalls, 1)
+ require.Equal(t, int64(1), redeemRepo.useCalls[0].id)
+ require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
+}
+
+func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts
new file mode 100644
index 00000000..3006484e
--- /dev/null
+++ b/frontend/src/api/__tests__/payment.spec.ts
@@ -0,0 +1,36 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { get, post } = vi.hoisted(() => ({
+ get: vi.fn(),
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ get,
+ post,
+ },
+}))
+
+import { paymentAPI } from '@/api/payment'
+
+describe('payment api', () => {
+ beforeEach(() => {
+ get.mockReset()
+ post.mockReset()
+ get.mockResolvedValue({ data: {} })
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('does not expose anonymous public out_trade_no verification', () => {
+ expect(Object.prototype.hasOwnProperty.call(paymentAPI, 'verifyOrderPublic')).toBe(false)
+ })
+
+ it('keeps signed public resume-token resolve endpoint', async () => {
+ await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', {
+ resume_token: 'resume-token-123',
+ })
+ })
+})
diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts
index 91b16866..e866e184 100644
--- a/frontend/src/api/payment.ts
+++ b/frontend/src/api/payment.ts
@@ -67,11 +67,6 @@ export const paymentAPI = {
return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo })
},
- /** Verify order payment status without auth (public endpoint for result page) */
- verifyOrderPublic(outTradeNo: string) {
- return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo })
- },
-
/** Resolve an order from a signed resume token without auth */
resolveOrderPublicByResumeToken(resumeToken: string) {
return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken })
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index 56e64793..34ced07a 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -7,7 +7,6 @@ const routeState = vi.hoisted(() => ({
const routerPush = vi.hoisted(() => vi.fn())
const pollOrderStatus = vi.hoisted(() => vi.fn())
-const verifyOrderPublic = vi.hoisted(() => vi.fn())
const verifyOrder = vi.hoisted(() => vi.fn())
const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn())
@@ -38,7 +37,6 @@ vi.mock('@/stores/payment', () => ({
vi.mock('@/api/payment', () => ({
paymentAPI: {
- verifyOrderPublic,
verifyOrder,
resolveOrderPublicByResumeToken,
},
@@ -67,7 +65,6 @@ describe('PaymentResultView', () => {
routeState.query = {}
routerPush.mockReset()
pollOrderStatus.mockReset()
- verifyOrderPublic.mockReset()
verifyOrder.mockReset()
resolveOrderPublicByResumeToken.mockReset()
window.localStorage.clear()
@@ -110,7 +107,6 @@ describe('PaymentResultView', () => {
await flushPromises()
expect(pollOrderStatus).toHaveBeenCalledWith(42)
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(wrapper.text()).toContain('payment.result.processing')
expect(wrapper.text()).not.toContain('payment.result.success')
expect(wrapper.text()).not.toContain('payment.result.failed')
@@ -221,7 +217,6 @@ describe('PaymentResultView', () => {
await flushPromises()
expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-fail')
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(verifyOrder).not.toHaveBeenCalled()
})
@@ -241,7 +236,6 @@ describe('PaymentResultView', () => {
await flushPromises()
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(verifyOrder).not.toHaveBeenCalled()
})
@@ -260,7 +254,6 @@ describe('PaymentResultView', () => {
await flushPromises()
- expect(verifyOrderPublic).not.toHaveBeenCalled()
expect(verifyOrder).not.toHaveBeenCalled()
})
@@ -284,7 +277,6 @@ describe('PaymentResultView', () => {
expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-77')
expect(wrapper.text()).toContain('payment.result.success')
- expect(verifyOrderPublic).not.toHaveBeenCalled()
})
it('normalizes aliased payment methods before rendering the label', async () => {
--
GitLab
From 440536a93da91405472cd0072728b70ad4d9a733 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 11:41:14 +0800
Subject: [PATCH 130/261] docs: align wechat payment required fields
---
docs/PAYMENT.md | 4 ++--
docs/PAYMENT_CN.md | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 2cc8f566..9322f7bf 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -141,8 +141,8 @@ Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 pa
| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
| **APIv3 Key** | 32-byte APIv3 key | Yes |
| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
-| **WeChat Pay Public Key ID** | WeChat Pay public key ID | No |
-| **Certificate Serial Number** | Merchant certificate serial number | No |
+| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
+| **Certificate Serial Number** | Merchant certificate serial number | Yes |
### Stripe
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index 95bfb990..0fbc198a 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -141,8 +141,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
-| **微信支付公钥 ID** | 微信支付公钥 ID | 否 |
-| **商户证书序列号** | 商户证书序列号 | 否 |
+| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
+| **商户证书序列号** | 商户证书序列号 | 是 |
### Stripe
--
GitLab
From 960b2bb8e61ed33a7679a00eec5faff2490b5637 Mon Sep 17 00:00:00 2001
From: shaw
Date: Tue, 21 Apr 2026 11:14:40 +0800
Subject: [PATCH 131/261] feat(legal): add CLA with automated GitHub Actions
enforcement
Introduce Individual Contributor License Agreement (ICLA) to enable
dual licensing (LGPL-V3 open source + future closed-source releases).
- CLA.md: Apache ICLA-style license grant with moral rights waiver,
patent license, electronic signature clause, and assignability
- .github/workflows/cla.yml: CLA Assistant Lite bot that auto-checks
PRs, posts signing prompts, and stores signatures on a separate
`cla-signatures` branch to keep main branch history clean
---
.github/workflows/cla.yml | 59 +++++++++++++++++++++++++++++++
CLA.md | 73 +++++++++++++++++++++++++++++++++++++++
2 files changed, 132 insertions(+)
create mode 100644 .github/workflows/cla.yml
create mode 100644 CLA.md
diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml
new file mode 100644
index 00000000..67c8d6e9
--- /dev/null
+++ b/.github/workflows/cla.yml
@@ -0,0 +1,59 @@
+name: "CLA Assistant"
+
+on:
+ issue_comment:
+ types: [created]
+ pull_request_target:
+ types: [opened, reopened, closed, synchronize]
+
+permissions:
+ actions: write
+ contents: write
+ pull-requests: write
+ statuses: write
+
+jobs:
+ cla-check:
+ if: |
+ github.event_name == 'issue_comment' ||
+ (github.event_name == 'pull_request_target' && github.event.action != 'closed')
+ runs-on: ubuntu-latest
+ steps:
+ - name: "CLA Assistant"
+ if: |
+ (github.event.comment.body == 'recheck' ||
+ github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') ||
+ github.event_name == 'pull_request_target'
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ allowlist: "dependabot[bot],renovate[bot],bot*"
+ lock-pullrequest-aftermerge: false
+ custom-notsigned-prcomment: |
+ Thank you for your contribution! Before we can merge this PR, we need $you to sign our [Contributor License Agreement (CLA)](https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md).
+
+ **To sign**, please reply with the following comment:
+
+ > I have read the CLA Document and I hereby sign the CLA
+
+ You only need to sign once — it will be valid for all your future contributions to this project.
+ custom-pr-sign-comment: "I have read the CLA Document and I hereby sign the CLA"
+ custom-allsigned-prcomment: "All contributors have signed the CLA. ✅"
+
+ cla-lock:
+ if: github.event_name == 'pull_request_target' && github.event.action == 'closed' && github.event.pull_request.merged == true
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Lock merged PR"
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ lock-pullrequest-aftermerge: true
diff --git a/CLA.md b/CLA.md
new file mode 100644
index 00000000..ed0d74b8
--- /dev/null
+++ b/CLA.md
@@ -0,0 +1,73 @@
+# Sub2API Individual Contributor License Agreement (v1.0)
+
+Thank you for your interest in contributing to Sub2API ("the Project"). This Contributor License Agreement ("Agreement") documents the rights granted by contributors to the Project.
+
+By signing this Agreement, you accept and agree to the following terms and conditions for your present and future contributions submitted to the Project.
+
+## 1. Definitions
+
+- **"You" (or "Your")** means the copyright owner or legal entity authorized by the copyright owner that is making this Agreement.
+- **"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the Project for inclusion in, or documentation of, any of the products owned or managed by the Project. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Project or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Project for the purpose of discussing and improving the Project, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
+- **"Project Owner"** means Wesley Liddick, or any individual or legal entity to whom Wesley Liddick has explicitly assigned or transferred ownership of the Project in writing, and their respective successors and assigns.
+
+## 2. Grant of Copyright License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. This license includes, without limitation, the right to sublicense, assign, and transfer these rights to any third party, including without limitation any successor, assignee, or acquiring entity of the Project or the Project Owner, and to use Your Contributions under any license, including proprietary or commercial licenses.
+
+## 3. Moral Rights
+
+To the fullest extent permitted by applicable law, You irrevocably waive and agree not to assert any moral rights (including rights of attribution and integrity) that You may have in Your Contributions, and agree that the Project Owner and its licensees may use, modify, and distribute Your Contributions without attribution or other obligations arising from moral rights.
+
+## 4. Grant of Patent License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Your Contributions, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Project to which such Contribution(s) was submitted.
+
+## 5. Representations and Warranties
+
+You represent and warrant that:
+
+(a) You are legally entitled to grant the above licenses.
+
+(b) If Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You have received permission to make Contributions on behalf of that employer, or that Your employer has waived such rights for Your Contributions to the Project.
+
+(c) Each of Your Contributions is Your original creation, or You have sufficient rights to submit it under the terms of this Agreement. You agree to provide, upon request, reasonable documentation or explanation of any third-party materials included in Your Contributions.
+
+## 6. No Warranty
+
+Your Contributions are provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
+
+## 7. No Obligation
+
+You understand that the decision to include Your Contribution in any product or project is entirely at the discretion of the Project Owner, and this Agreement does not obligate the Project Owner to use Your Contribution.
+
+## 8. Retention of Rights
+
+You retain ownership of the copyright in Your Contributions. This Agreement does not transfer any copyright or other intellectual property rights from You to the Project Owner. This Agreement only grants the licenses described above.
+
+## 9. Term and Termination
+
+This Agreement shall remain in effect indefinitely. You may terminate this Agreement prospectively by providing written notice to the Project Owner, but such termination shall not affect the licenses granted for Contributions submitted prior to the effective date of termination. The licenses granted herein for Contributions submitted prior to termination are perpetual and irrevocable.
+
+## 10. Electronic Signature
+
+You agree that Your electronic signature (including but not limited to typing a specific phrase in a pull request, issue, or other electronic communication) is legally binding and has the same force and effect as a handwritten signature. You consent to the use of electronic means to enter into this Agreement and acknowledge that this Agreement is enforceable as if executed in a traditional written format.
+
+## 11. General Provisions
+
+**Entire Agreement.** This Agreement constitutes the entire agreement between You and the Project Owner with respect to Your Contributions and supersedes all prior or contemporaneous understandings regarding such subject matter.
+
+**Severability.** If any provision of this Agreement is held to be unenforceable or invalid, that provision will be enforced to the maximum extent possible and the remaining provisions will remain in full force and effect.
+
+**No Waiver.** The failure of the Project Owner to enforce any provision of this Agreement shall not constitute a waiver of that provision or any other provision.
+
+**Amendment.** This Agreement may only be modified by a written instrument signed by both parties. Modifications to this Agreement apply only to Contributions submitted after the modified Agreement is published and accepted by You. Prior Contributions remain governed by the version of the Agreement in effect at the time of submission.
+
+**Notification.** Notices under this Agreement shall be sent to the Project Owner via a GitHub issue on the Project repository. Notices are effective upon receipt.
+
+---
+
+**By signing this CLA, you acknowledge that you have read and understood this Agreement and agree to be bound by its terms.**
+
+To sign, reply in the pull request with:
+
+> I have read the CLA Document and I hereby sign the CLA
--
GitLab
From 561405ab0033ef795cdc1fc752fabcfcbc333b65 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 12:41:27 +0800
Subject: [PATCH 132/261] feat: add payment order provider snapshots
---
backend/ent/migrate/schema.go | 15 +--
backend/ent/mutation.go | 75 ++++++++++-
backend/ent/paymentorder.go | 16 +++
backend/ent/paymentorder/paymentorder.go | 3 +
backend/ent/paymentorder/where.go | 10 ++
backend/ent/paymentorder_create.go | 70 +++++++++++
backend/ent/paymentorder_update.go | 36 ++++++
backend/ent/runtime/runtime.go | 16 +--
backend/ent/schema/payment_order.go | 3 +
.../internal/handler/admin/payment_handler.go | 25 +++-
backend/internal/handler/payment_handler.go | 27 +++-
backend/internal/service/payment_order.go | 49 +++++++-
.../payment_order_provider_snapshot_test.go | 116 ++++++++++++++++++
...17_add_payment_order_provider_snapshot.sql | 2 +
14 files changed, 440 insertions(+), 23 deletions(-)
create mode 100644 backend/internal/service/payment_order_provider_snapshot_test.go
create mode 100644 backend/migrations/117_add_payment_order_provider_snapshot.sql
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 230ea060..81f6a664 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -655,6 +655,7 @@ var (
{Name: "subscription_days", Type: field.TypeInt, Nullable: true},
{Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64},
{Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30},
+ {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"},
{Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
{Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
@@ -683,7 +684,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "payment_orders_users_payment_orders",
- Columns: []*schema.Column{PaymentOrdersColumns[38]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -697,32 +698,32 @@ var (
{
Name: "paymentorder_user_id",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[38]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
},
{
Name: "paymentorder_status",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[20]},
+ Columns: []*schema.Column{PaymentOrdersColumns[21]},
},
{
Name: "paymentorder_expires_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[29]},
},
{
Name: "paymentorder_created_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[36]},
+ Columns: []*schema.Column{PaymentOrdersColumns[37]},
},
{
Name: "paymentorder_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[29]},
+ Columns: []*schema.Column{PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_payment_type_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[29]},
+ Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_order_type",
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 5227015c..ec4a4070 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -15386,6 +15386,7 @@ type PaymentOrderMutation struct {
addsubscription_days *int
provider_instance_id *string
provider_key *string
+ provider_snapshot *map[string]interface{}
status *string
refund_amount *float64
addrefund_amount *float64
@@ -16471,6 +16472,55 @@ func (m *PaymentOrderMutation) ResetProviderKey() {
delete(m.clearedFields, paymentorder.FieldProviderKey)
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) {
+ m.provider_snapshot = &value
+}
+
+// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation.
+func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) {
+ v := m.provider_snapshot
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSnapshot requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err)
+ }
+ return oldValue.ProviderSnapshot, nil
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ClearProviderSnapshot() {
+ m.provider_snapshot = nil
+ m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{}
+}
+
+// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot]
+ return ok
+}
+
+// ResetProviderSnapshot resets all changes to the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ResetProviderSnapshot() {
+ m.provider_snapshot = nil
+ delete(m.clearedFields, paymentorder.FieldProviderSnapshot)
+}
+
// SetStatus sets the "status" field.
func (m *PaymentOrderMutation) SetStatus(s string) {
m.status = &s
@@ -17330,7 +17380,7 @@ func (m *PaymentOrderMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentOrderMutation) Fields() []string {
- fields := make([]string, 0, 38)
+ fields := make([]string, 0, 39)
if m.user != nil {
fields = append(fields, paymentorder.FieldUserID)
}
@@ -17391,6 +17441,9 @@ func (m *PaymentOrderMutation) Fields() []string {
if m.provider_key != nil {
fields = append(fields, paymentorder.FieldProviderKey)
}
+ if m.provider_snapshot != nil {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.status != nil {
fields = append(fields, paymentorder.FieldStatus)
}
@@ -17493,6 +17546,8 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
return m.ProviderInstanceID()
case paymentorder.FieldProviderKey:
return m.ProviderKey()
+ case paymentorder.FieldProviderSnapshot:
+ return m.ProviderSnapshot()
case paymentorder.FieldStatus:
return m.Status()
case paymentorder.FieldRefundAmount:
@@ -17578,6 +17633,8 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldProviderInstanceID(ctx)
case paymentorder.FieldProviderKey:
return m.OldProviderKey(ctx)
+ case paymentorder.FieldProviderSnapshot:
+ return m.OldProviderSnapshot(ctx)
case paymentorder.FieldStatus:
return m.OldStatus(ctx)
case paymentorder.FieldRefundAmount:
@@ -17763,6 +17820,13 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
}
m.SetProviderKey(v)
return nil
+ case paymentorder.FieldProviderSnapshot:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSnapshot(v)
+ return nil
case paymentorder.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -18033,6 +18097,9 @@ func (m *PaymentOrderMutation) ClearedFields() []string {
if m.FieldCleared(paymentorder.FieldProviderKey) {
fields = append(fields, paymentorder.FieldProviderKey)
}
+ if m.FieldCleared(paymentorder.FieldProviderSnapshot) {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.FieldCleared(paymentorder.FieldRefundReason) {
fields = append(fields, paymentorder.FieldRefundReason)
}
@@ -18104,6 +18171,9 @@ func (m *PaymentOrderMutation) ClearField(name string) error {
case paymentorder.FieldProviderKey:
m.ClearProviderKey()
return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ClearProviderSnapshot()
+ return nil
case paymentorder.FieldRefundReason:
m.ClearRefundReason()
return nil
@@ -18202,6 +18272,9 @@ func (m *PaymentOrderMutation) ResetField(name string) error {
case paymentorder.FieldProviderKey:
m.ResetProviderKey()
return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ResetProviderSnapshot()
+ return nil
case paymentorder.FieldStatus:
m.ResetStatus()
return nil
diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go
index a58823ee..b131b8c8 100644
--- a/backend/ent/paymentorder.go
+++ b/backend/ent/paymentorder.go
@@ -3,6 +3,7 @@
package ent
import (
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -58,6 +59,8 @@ type PaymentOrder struct {
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
// ProviderKey holds the value of the "provider_key" field.
ProviderKey *string `json:"provider_key,omitempty"`
+ // ProviderSnapshot holds the value of the "provider_snapshot" field.
+ ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// RefundAmount holds the value of the "refund_amount" field.
@@ -125,6 +128,8 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case paymentorder.FieldProviderSnapshot:
+ values[i] = new([]byte)
case paymentorder.FieldForceRefund:
values[i] = new(sql.NullBool)
case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount:
@@ -285,6 +290,14 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error {
_m.ProviderKey = new(string)
*_m.ProviderKey = value.String
}
+ case paymentorder.FieldProviderSnapshot:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil {
+ return fmt.Errorf("unmarshal field provider_snapshot: %w", err)
+ }
+ }
case paymentorder.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -522,6 +535,9 @@ func (_m *PaymentOrder) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ builder.WriteString("provider_snapshot=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot))
+ builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go
index af9b1422..62883794 100644
--- a/backend/ent/paymentorder/paymentorder.go
+++ b/backend/ent/paymentorder/paymentorder.go
@@ -54,6 +54,8 @@ const (
FieldProviderInstanceID = "provider_instance_id"
// FieldProviderKey holds the string denoting the provider_key field in the database.
FieldProviderKey = "provider_key"
+ // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database.
+ FieldProviderSnapshot = "provider_snapshot"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldRefundAmount holds the string denoting the refund_amount field in the database.
@@ -126,6 +128,7 @@ var Columns = []string{
FieldSubscriptionDays,
FieldProviderInstanceID,
FieldProviderKey,
+ FieldProviderSnapshot,
FieldStatus,
FieldRefundAmount,
FieldRefundReason,
diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go
index 0f6b74a0..e96bf51e 100644
--- a/backend/ent/paymentorder/where.go
+++ b/backend/ent/paymentorder/where.go
@@ -1440,6 +1440,16 @@ func ProviderKeyContainsFold(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v))
}
+// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot))
+}
+
+// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot))
+}
+
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go
index 497ba52c..3ee24f8e 100644
--- a/backend/ent/paymentorder_create.go
+++ b/backend/ent/paymentorder_create.go
@@ -239,6 +239,12 @@ func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCre
return _c
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate {
+ _c.mutation.SetProviderSnapshot(v)
+ return _c
+}
+
// SetStatus sets the "status" field.
func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate {
_c.mutation.SetStatus(v)
@@ -771,6 +777,10 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec)
_spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
_node.ProviderKey = &value
}
+ if value, ok := _c.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ _node.ProviderSnapshot = value
+ }
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -1242,6 +1252,24 @@ func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert {
return u
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderSnapshot, v)
+ return u
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert {
u.Set(paymentorder.FieldStatus, v)
@@ -1942,6 +1970,27 @@ func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne {
})
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne {
return u.Update(func(s *PaymentOrderUpsert) {
@@ -2853,6 +2902,27 @@ func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk {
})
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk {
return u.Update(func(s *PaymentOrderUpsert) {
diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go
index 9a901415..378e0dad 100644
--- a/backend/ent/paymentorder_update.go
+++ b/backend/ent/paymentorder_update.go
@@ -405,6 +405,18 @@ func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate {
return _u
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate {
_u.mutation.SetStatus(v)
@@ -941,6 +953,12 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error
if _u.mutation.ProviderKeyCleared() {
_spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
@@ -1450,6 +1468,18 @@ func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne {
return _u
}
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne {
_u.mutation.SetStatus(v)
@@ -2016,6 +2046,12 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd
if _u.mutation.ProviderKeyCleared() {
_spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index b7118ac9..bdb7f7a9 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -728,37 +728,37 @@ func init() {
// paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error)
// paymentorderDescStatus is the schema descriptor for status field.
- paymentorderDescStatus := paymentorderFields[20].Descriptor()
+ paymentorderDescStatus := paymentorderFields[21].Descriptor()
// paymentorder.DefaultStatus holds the default value on creation for the status field.
paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string)
// paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save.
paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error)
// paymentorderDescRefundAmount is the schema descriptor for refund_amount field.
- paymentorderDescRefundAmount := paymentorderFields[21].Descriptor()
+ paymentorderDescRefundAmount := paymentorderFields[22].Descriptor()
// paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field.
paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64)
// paymentorderDescForceRefund is the schema descriptor for force_refund field.
- paymentorderDescForceRefund := paymentorderFields[24].Descriptor()
+ paymentorderDescForceRefund := paymentorderFields[25].Descriptor()
// paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field.
paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool)
// paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field.
- paymentorderDescRefundRequestedBy := paymentorderFields[27].Descriptor()
+ paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor()
// paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error)
// paymentorderDescClientIP is the schema descriptor for client_ip field.
- paymentorderDescClientIP := paymentorderFields[33].Descriptor()
+ paymentorderDescClientIP := paymentorderFields[34].Descriptor()
// paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error)
// paymentorderDescSrcHost is the schema descriptor for src_host field.
- paymentorderDescSrcHost := paymentorderFields[34].Descriptor()
+ paymentorderDescSrcHost := paymentorderFields[35].Descriptor()
// paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error)
// paymentorderDescCreatedAt is the schema descriptor for created_at field.
- paymentorderDescCreatedAt := paymentorderFields[36].Descriptor()
+ paymentorderDescCreatedAt := paymentorderFields[37].Descriptor()
// paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time)
// paymentorderDescUpdatedAt is the schema descriptor for updated_at field.
- paymentorderDescUpdatedAt := paymentorderFields[37].Descriptor()
+ paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor()
// paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time)
// paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go
index 64378de1..5815d032 100644
--- a/backend/ent/schema/payment_order.go
+++ b/backend/ent/schema/payment_order.go
@@ -95,6 +95,9 @@ func (PaymentOrder) Fields() []ent.Field {
Optional().
Nillable().
MaxLen(30),
+ field.JSON("provider_snapshot", map[string]any{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// 状态
field.String("status").
diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go
index b0ed6aed..84359cd9 100644
--- a/backend/internal/handler/admin/payment_handler.go
+++ b/backend/internal/handler/admin/payment_handler.go
@@ -3,6 +3,7 @@ package admin
import (
"strconv"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrderDetail returns detailed information about a single order.
@@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
return
}
auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
- response.Success(c, gin.H{"order": order, "auditLogs": auditLogs})
+ response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs})
}
// CancelOrder cancels a pending order (admin).
@@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
response.Success(c, gin.H{"message": "fulfillment retried"})
}
+func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizeAdminPaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
// AdminProcessRefundRequest is the request body for admin refund processing.
type AdminProcessRefundRequest struct {
Amount float64 `json:"amount"`
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 273aea73..0fba4726 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -6,6 +6,7 @@ import (
"strconv"
"strings"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -327,7 +328,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrder returns a single order for the authenticated user.
@@ -349,7 +350,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// CancelOrder cancels a pending order for the authenticated user.
@@ -445,7 +446,7 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// PublicOrderResult is the limited order info returned by the public verify endpoint.
@@ -523,6 +524,26 @@ func isMobile(c *gin.Context) bool {
return false
}
+func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizePaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
func isWeChatBrowser(c *gin.Context) bool {
return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger")
}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 7d973b92..1f01bc11 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -73,7 +73,7 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if oauthResp != nil {
return oauthResp, nil
}
- order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount)
+ order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel)
if err != nil {
return nil, err
}
@@ -122,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return plan, nil
}
-func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
+func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
@@ -139,6 +139,13 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
+ providerSnapshot := buildPaymentOrderProviderSnapshot(sel)
+ selectedInstanceID := ""
+ selectedProviderKey := ""
+ if sel != nil {
+ selectedInstanceID = strings.TrimSpace(sel.InstanceID)
+ selectedProviderKey = strings.TrimSpace(sel.ProviderKey)
+ }
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
@@ -159,6 +166,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
+ if selectedInstanceID != "" {
+ b.SetProviderInstanceID(selectedInstanceID)
+ }
+ if selectedProviderKey != "" {
+ b.SetProviderKey(selectedProviderKey)
+ }
+ if providerSnapshot != nil {
+ b.SetProviderSnapshot(providerSnapshot)
+ }
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
@@ -192,6 +208,35 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
+func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any {
+ if sel == nil {
+ return nil
+ }
+
+ snapshot := map[string]any{}
+ snapshot["schema_version"] = 1
+
+ instanceID := strings.TrimSpace(sel.InstanceID)
+ if instanceID != "" {
+ snapshot["provider_instance_id"] = instanceID
+ }
+
+ providerKey := strings.TrimSpace(sel.ProviderKey)
+ if providerKey != "" {
+ snapshot["provider_key"] = providerKey
+ }
+
+ paymentMode := strings.TrimSpace(sel.PaymentMode)
+ if paymentMode != "" {
+ snapshot["payment_mode"] = paymentMode
+ }
+
+ if len(snapshot) == 1 {
+ return nil
+ }
+ return snapshot
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
new file mode 100644
index 00000000..c75566bc
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -0,0 +1,116 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "strconv"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) {
+ t.Parallel()
+
+ sel := &payment.InstanceSelection{
+ InstanceID: "12",
+ ProviderKey: payment.TypeWxpay,
+ SupportedTypes: "wxpay,wxpay_direct",
+ PaymentMode: "popup",
+ Config: map[string]string{
+ "privateKey": "secret",
+ "apiV3Key": "secret-v3",
+ "appId": "wx-app-id",
+ },
+ }
+
+ snapshot := buildPaymentOrderProviderSnapshot(sel)
+ require.Equal(t, map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "12",
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "popup",
+ }, snapshot)
+ require.NotContains(t, snapshot, "config")
+ require.NotContains(t, snapshot, "privateKey")
+ require.NotContains(t, snapshot, "apiV3Key")
+ require.NotContains(t, snapshot, "supported_types")
+ require.NotContains(t, snapshot, "instance_name")
+}
+
+func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("snapshot@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instance, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Primary Alipay").
+ SetConfig(`{"secretKey":"do-not-copy"}`).
+ SetSupportedTypes("alipay,alipay_direct").
+ SetPaymentMode("redirect").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{entClient: client}
+ order, err := svc.createOrderInTx(
+ ctx,
+ CreateOrderRequest{
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ OrderType: payment.OrderTypeBalance,
+ ClientIP: "127.0.0.1",
+ SrcHost: "app.example.com",
+ },
+ &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ },
+ nil,
+ &PaymentConfig{
+ MaxPendingOrders: 3,
+ OrderTimeoutMin: 30,
+ },
+ 88,
+ 88,
+ 0,
+ 88,
+ &payment.InstanceSelection{
+ InstanceID: strconv.FormatInt(instance.ID, 10),
+ ProviderKey: payment.TypeAlipay,
+ SupportedTypes: "alipay,alipay_direct",
+ PaymentMode: "redirect",
+ Config: map[string]string{
+ "secretKey": "do-not-copy",
+ },
+ },
+ )
+ require.NoError(t, err)
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
+ require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
+ require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"])
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
+ require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
+ require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
+ require.NotContains(t, order.ProviderSnapshot, "config")
+ require.NotContains(t, order.ProviderSnapshot, "secretKey")
+ require.NotContains(t, order.ProviderSnapshot, "supported_types")
+ require.NotContains(t, order.ProviderSnapshot, "instance_name")
+}
+
+func valueOrEmpty(v *string) string {
+ if v == nil {
+ return ""
+ }
+ return *v
+}
diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql
new file mode 100644
index 00000000..56a5fe2d
--- /dev/null
+++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql
@@ -0,0 +1,2 @@
+ALTER TABLE payment_orders
+ADD COLUMN IF NOT EXISTS provider_snapshot JSONB;
--
GitLab
From 35aeeaa6e17632ab08bd7f7f85d5edc396e4ce0f Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 12:50:55 +0800
Subject: [PATCH 133/261] fix: pin payment read paths to provider snapshots
---
.../internal/service/payment_fulfillment.go | 2 +-
.../service/payment_fulfillment_test.go | 24 +++
.../payment_order_provider_snapshot.go | 115 ++++++++++++++
backend/internal/service/payment_refund.go | 4 +
.../service/payment_webhook_provider.go | 2 +-
.../service/payment_webhook_provider_test.go | 150 ++++++++++++++++++
6 files changed, 295 insertions(+), 2 deletions(-)
create mode 100644 backend/internal/service/payment_order_provider_snapshot.go
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 83bac21d..9cb03cca 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
instanceProviderKey = inst.ProviderKey
}
- expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey)
+ expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
"expectedProvider": expectedProviderKey,
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 712129b0..3ce82973 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -7,6 +7,7 @@ import (
"errors"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
@@ -240,3 +241,26 @@ func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testi
expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
)
}
+
+func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_key": payment.TypeEasyPay,
+ },
+ }
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKeyForOrder(registry, order, ""),
+ )
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
new file mode 100644
index 00000000..9a0aa106
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -0,0 +1,115 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+type paymentOrderProviderSnapshot struct {
+ SchemaVersion int
+ ProviderInstanceID string
+ ProviderKey string
+ PaymentMode string
+}
+
+func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
+ if order == nil || len(order.ProviderSnapshot) == 0 {
+ return nil
+ }
+
+ snapshot := &paymentOrderProviderSnapshot{
+ SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
+ ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
+ ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
+ PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
+ }
+ if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
+ return nil
+ }
+ return snapshot
+}
+
+func psSnapshotStringValue(value any) string {
+ switch typed := value.(type) {
+ case string:
+ return strings.TrimSpace(typed)
+ default:
+ return ""
+ }
+}
+
+func psSnapshotIntValue(value any) int {
+ switch typed := value.(type) {
+ case int:
+ return typed
+ case int32:
+ return int(typed)
+ case int64:
+ return int(typed)
+ case float32:
+ return int(typed)
+ case float64:
+ return int(typed)
+ case string:
+ n, err := strconv.Atoi(strings.TrimSpace(typed))
+ if err == nil {
+ return n
+ }
+ }
+ return 0
+}
+
+func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || order == nil || snapshot == nil {
+ return nil, nil
+ }
+
+ snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
+ columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ if snapshotInstanceID == "" {
+ snapshotInstanceID = columnInstanceID
+ }
+ if snapshotInstanceID == "" {
+ return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
+ }
+ if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
+ return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
+ }
+
+ instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
+ }
+
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
+ }
+ return nil, err
+ }
+
+ if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
+ return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
+ }
+
+ return inst, nil
+}
+
+func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
+ if order == nil {
+ return strings.TrimSpace(instanceProviderKey)
+ }
+
+ orderProviderKey := psStringValue(order.ProviderKey)
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+
+ return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
+}
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index 57469fa3..fbaeff99 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -27,6 +27,10 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
return nil, nil
}
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
if instIDStr == "" {
return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
index 82dc9ea3..f2da40d9 100644
--- a/backend/internal/service/payment_webhook_provider.go
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -113,7 +113,7 @@ func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, pro
}
func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
- return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""
+ return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
}
func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
index 15b447c2..f12cf691 100644
--- a/backend/internal/service/payment_webhook_provider_test.go
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
+ "strconv"
"testing"
"time"
@@ -205,6 +206,72 @@ func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupp
require.Nil(t, got)
}
+func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-snapshot").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 42,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(inst.ID, 10),
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-legacy-fallback").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 43,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "999999",
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.Nil(t, got)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
+}
+
func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
@@ -364,3 +431,86 @@ func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "provider instance")
}
+
+func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("snapshot-webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-snapshot-a",
+ "mchId": "mch-snapshot-a",
+ "privateKey": "private-key-snapshot-a",
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": "public-key-snapshot-a",
+ "publicKeyId": "public-key-id-snapshot-a",
+ "certSerial": "cert-serial-snapshot-a",
+ })
+ wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-snapshot-b",
+ "mchId": "mch-snapshot-b",
+ "privateKey": "private-key-snapshot-b",
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": "public-key-snapshot-b",
+ "publicKeyId": "public-key-id-snapshot-b",
+ "certSerial": "cert-serial-snapshot-b",
+ })
+ instA, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-a").
+ SetConfig(wxpayConfigA).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-b").
+ SetConfig(wxpayConfigB).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(66).
+ SetPayAmount(66).
+ SetFeeRate(0).
+ SetRechargeCode("SNAPSHOT-WEBHOOK").
+ SetOutTradeNo("sub2_test_snapshot_webhook_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(instA.ID, 10),
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "native",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
+}
--
GitLab
From 119f784d19bf6528aa45bb14a2ffcb2fe91dd6b2 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 12:57:35 +0800
Subject: [PATCH 134/261] fix: validate wxpay payments against order snapshots
---
backend/internal/payment/provider/wxpay.go | 43 +++++++++++++-
.../internal/payment/provider/wxpay_test.go | 28 +++++++++
backend/internal/payment/types.go | 20 ++++---
.../internal/service/payment_fulfillment.go | 57 ++++++++++++++++++-
.../service/payment_fulfillment_test.go | 43 ++++++++++++++
backend/internal/service/payment_order.go | 26 ++++++++-
.../service/payment_order_lifecycle.go | 2 +-
.../payment_order_provider_snapshot.go | 14 ++++-
.../payment_order_provider_snapshot_test.go | 28 ++++++++-
9 files changed, 239 insertions(+), 22 deletions(-)
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 7d51dff0..30016338 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -32,6 +32,13 @@ const (
wxpayResultPath = "/payment/result"
)
+const (
+ wxpayMetadataAppID = "appid"
+ wxpayMetadataMerchantID = "mchid"
+ wxpayMetadataCurrency = "currency"
+ wxpayMetadataTradeState = "trade_state"
+)
+
// WeChat Pay create-payment modes.
const (
wxpayModeNative = "native"
@@ -355,6 +362,32 @@ func mapWxState(s string) string {
}
}
+func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
+ if tx == nil {
+ return nil
+ }
+
+ metadata := map[string]string{}
+ if appID := wxSV(tx.Appid); appID != "" {
+ metadata[wxpayMetadataAppID] = appID
+ }
+ if merchantID := wxSV(tx.Mchid); merchantID != "" {
+ metadata[wxpayMetadataMerchantID] = merchantID
+ }
+ if tradeState := wxSV(tx.TradeState); tradeState != "" {
+ metadata[wxpayMetadataTradeState] = tradeState
+ }
+ if tx.Amount != nil {
+ if currency := wxSV(tx.Amount.Currency); currency != "" {
+ metadata[wxpayMetadataCurrency] = currency
+ }
+ }
+ if len(metadata) == 0 {
+ return nil
+ }
+ return metadata
+}
+
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient()
if err != nil {
@@ -379,7 +412,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO
if tx.SuccessTime != nil {
pa = *tx.SuccessTime
}
- return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: id,
+ Status: mapWxState(wxSV(tx.TradeState)),
+ Amount: amt,
+ PaidAt: pa,
+ Metadata: buildWxpayTransactionMetadata(tx),
+ }, nil
}
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
@@ -411,7 +450,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers
}
return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
- Amount: amt, Status: st, RawData: rawBody,
+ Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
}, nil
}
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index 6d0006be..b3f4f648 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -10,6 +10,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
@@ -102,6 +103,33 @@ func TestWxSV(t *testing.T) {
}
}
+func TestBuildWxpayTransactionMetadata(t *testing.T) {
+ t.Parallel()
+
+ tx := &payments.Transaction{
+ Appid: strPtr("wx-app-id"),
+ Mchid: strPtr("mch-id"),
+ TradeState: strPtr(wxpayTradeStateSuccess),
+ Amount: &payments.Amount{
+ Currency: strPtr(wxpayCurrency),
+ },
+ }
+
+ metadata := buildWxpayTransactionMetadata(tx)
+ if metadata[wxpayMetadataAppID] != "wx-app-id" {
+ t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
+ }
+ if metadata[wxpayMetadataMerchantID] != "mch-id" {
+ t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
+ }
+ if metadata[wxpayMetadataCurrency] != wxpayCurrency {
+ t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
+ }
+ if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
+ t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
+ }
+}
+
func strPtr(s string) *string {
return &s
}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
index bb125247..29abf82b 100644
--- a/backend/internal/payment/types.go
+++ b/backend/internal/payment/types.go
@@ -149,19 +149,21 @@ type CreatePaymentResponse struct {
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
- TradeNo string
- Status string // "pending", "paid", "failed", "refunded"
- Amount float64 // Amount in CNY
- PaidAt string // RFC3339 timestamp or empty
+ TradeNo string
+ Status string // "pending", "paid", "failed", "refunded"
+ Amount float64 // Amount in CNY
+ PaidAt string // RFC3339 timestamp or empty
+ Metadata map[string]string
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
- TradeNo string
- OrderID string
- Amount float64
- Status string // "success" or "failed"
- RawData string // Raw notification body for audit
+ TradeNo string
+ OrderID string
+ Amount float64
+ Status string // "success" or "failed"
+ RawData string // Raw notification body for audit
+ Metadata map[string]string
}
// RefundRequest contains the parameters for requesting a refund.
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 9cb03cca..7bde03c8 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -28,14 +28,14 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
- return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
- return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
-func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
+func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
@@ -54,6 +54,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
}
+ if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
+ "detail": err.Error(),
+ "tradeNo": tradeNo,
+ })
+ return err
+ }
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
@@ -69,6 +76,50 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
+func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) {
+ return nil
+ }
+
+ snapshot := psOrderProviderSnapshot(order)
+ if snapshot == nil {
+ return nil
+ }
+
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["appid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing appid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["mchid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing mchid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
+ actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing currency")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
+ return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
+ }
+
+ return nil
+}
+
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
if key := strings.TrimSpace(instanceProviderKey); key != "" {
return key
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 3ce82973..8883d3b8 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -264,3 +264,46 @@ func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testi
expectedNotificationProviderKeyForOrder(registry, order, ""),
)
}
+
+func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "merchant_app_id": "wx-app-expected",
+ "merchant_id": "mch-expected",
+ "currency": "CNY",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-other",
+ "mchid": "mch-expected",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.ErrorContains(t, err, "wxpay appid mismatch")
+}
+
+func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "9",
+ "provider_key": payment.TypeWxpay,
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-runtime",
+ "mchid": "mch-runtime",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.NoError(t, err)
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 1f01bc11..6ee490a8 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -139,7 +139,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
- providerSnapshot := buildPaymentOrderProviderSnapshot(sel)
+ providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
if sel != nil {
@@ -208,13 +208,13 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
-func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[string]any {
+func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
if sel == nil {
return nil
}
snapshot := map[string]any{}
- snapshot["schema_version"] = 1
+ snapshot["schema_version"] = 2
instanceID := strings.TrimSpace(sel.InstanceID)
if instanceID != "" {
@@ -231,12 +231,32 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection) map[strin
snapshot["payment_mode"] = paymentMode
}
+ if providerKey == payment.TypeWxpay {
+ if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ snapshot["currency"] = "CNY"
+ }
+
if len(snapshot) == 1 {
return nil
}
return snapshot
}
+func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
+ if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
+ return ""
+ }
+ if strings.TrimSpace(req.OpenID) != "" {
+ return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
+ }
+ return strings.TrimSpace(sel.Config["appId"])
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 1564c36d..c11baac1 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -163,7 +163,7 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
}
notificationTradeNo = upstreamTradeNo
}
- if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
index 9a0aa106..31a790c7 100644
--- a/backend/internal/service/payment_order_provider_snapshot.go
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -15,6 +15,9 @@ type paymentOrderProviderSnapshot struct {
ProviderInstanceID string
ProviderKey string
PaymentMode string
+ MerchantAppID string
+ MerchantID string
+ Currency string
}
func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
@@ -27,8 +30,17 @@ func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSna
ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
+ MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
+ MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
+ Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
}
- if snapshot.SchemaVersion == 0 && snapshot.ProviderInstanceID == "" && snapshot.ProviderKey == "" && snapshot.PaymentMode == "" {
+ if snapshot.SchemaVersion == 0 &&
+ snapshot.ProviderInstanceID == "" &&
+ snapshot.ProviderKey == "" &&
+ snapshot.PaymentMode == "" &&
+ snapshot.MerchantAppID == "" &&
+ snapshot.MerchantID == "" &&
+ snapshot.Currency == "" {
return nil
}
return snapshot
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
index c75566bc..bc6666a8 100644
--- a/backend/internal/service/payment_order_provider_snapshot_test.go
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -26,18 +26,21 @@ func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T)
},
}
- snapshot := buildPaymentOrderProviderSnapshot(sel)
+ snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
require.Equal(t, map[string]any{
- "schema_version": 1,
+ "schema_version": 2,
"provider_instance_id": "12",
"provider_key": payment.TypeWxpay,
"payment_mode": "popup",
+ "merchant_app_id": "wx-app-id",
+ "currency": "CNY",
}, snapshot)
require.NotContains(t, snapshot, "config")
require.NotContains(t, snapshot, "privateKey")
require.NotContains(t, snapshot, "apiV3Key")
require.NotContains(t, snapshot, "supported_types")
require.NotContains(t, snapshot, "instance_name")
+ require.NotContains(t, snapshot, "merchant_id")
}
func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
@@ -98,7 +101,7 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NoError(t, err)
require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
- require.Equal(t, float64(1), order.ProviderSnapshot["schema_version"])
+ require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
@@ -108,6 +111,25 @@ func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
require.NotContains(t, order.ProviderSnapshot, "instance_name")
}
+func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "88",
+ ProviderKey: payment.TypeWxpay,
+ Config: map[string]string{
+ "appId": "wx-open-app",
+ "mpAppId": "wx-mp-app",
+ "mchId": "mch-88",
+ },
+ PaymentMode: "jsapi",
+ }, CreateOrderRequest{OpenID: "openid-123"})
+
+ require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
+ require.Equal(t, "mch-88", snapshot["merchant_id"])
+ require.Equal(t, "CNY", snapshot["currency"])
+}
+
func valueOrEmpty(v *string) string {
if v == nil {
return ""
--
GitLab
From 276ce052a30eab0148b95cc7f22e1feb955f75b5 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:01:21 +0800
Subject: [PATCH 135/261] fix: align payment recovery query refs and resume
authority
---
.../service/payment_order_lifecycle.go | 61 +++++++++++--
.../service/payment_order_lifecycle_test.go | 89 +++++++++++++++++++
.../internal/service/payment_resume_lookup.go | 15 +++-
.../service/payment_resume_lookup_test.go | 57 ++++++++++++
4 files changed, 212 insertions(+), 10 deletions(-)
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index c11baac1..a192f599 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -139,20 +140,18 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil {
return ""
}
- // Use OutTradeNo as fallback when PaymentTradeNo is empty
- // (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
- tradeNo := o.PaymentTradeNo
- if tradeNo == "" {
- tradeNo = o.OutTradeNo
+ queryRef := paymentOrderQueryReference(o, prov)
+ if queryRef == "" {
+ return ""
}
- resp, err := prov.QueryOrder(ctx, tradeNo)
+ resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
notificationTradeNo := o.PaymentTradeNo
- if upstreamTradeNo := resp.TradeNo; upstreamTradeNo != "" && upstreamTradeNo != notificationTradeNo {
+ if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().
Where(paymentorder.IDEQ(o.ID)).
SetPaymentTradeNo(upstreamTradeNo).
@@ -170,11 +169,57 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return checkPaidResultAlreadyPaid
}
if cp, ok := prov.(payment.CancelableProvider); ok {
- _ = cp.CancelPayment(ctx, tradeNo)
+ _ = cp.CancelPayment(ctx, queryRef)
}
return ""
}
+func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
+ if order == nil {
+ return ""
+ }
+
+ providerKey := ""
+ if prov != nil {
+ providerKey = strings.TrimSpace(prov.ProviderKey())
+ }
+ if providerKey == "" {
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
+ providerKey = strings.TrimSpace(snapshot.ProviderKey)
+ }
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(order.PaymentType)
+ }
+
+ switch payment.GetBasePaymentType(providerKey) {
+ case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
+ return strings.TrimSpace(order.OutTradeNo)
+ default:
+ if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
+ return tradeNo
+ }
+ return strings.TrimSpace(order.OutTradeNo)
+ }
+}
+
+func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
+ upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
+ if upstreamTradeNo == "" {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
+ return false
+ }
+ return true
+}
+
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
index 3d4773a4..3c6c65a5 100644
--- a/backend/internal/service/payment_order_lifecycle_test.go
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -234,6 +234,95 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
+func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-existing-trade@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-existing-trade-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("upstream-trade-existing").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-existing",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
+}
+
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper()
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
index 048e489a..05626aa6 100644
--- a/backend/internal/service/payment_resume_lookup.go
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -21,10 +21,21 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
if claims.UserID > 0 && order.UserID != claims.UserID {
return nil, fmt.Errorf("resume token user mismatch")
}
- if claims.ProviderInstanceID != "" && strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != claims.ProviderInstanceID {
+ snapshot := psOrderProviderSnapshot(order)
+ orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
+ if snapshot != nil {
+ if snapshot.ProviderInstanceID != "" {
+ orderProviderInstanceID = snapshot.ProviderInstanceID
+ }
+ if snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+ }
+ if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
return nil, fmt.Errorf("resume token provider instance mismatch")
}
- if claims.ProviderKey != "" && strings.TrimSpace(psStringValue(order.ProviderKey)) != claims.ProviderKey {
+ if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey {
return nil, fmt.Errorf("resume token provider key mismatch")
}
if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType {
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
index 3a50b5bc..946e7aa2 100644
--- a/backend/internal/service/payment_resume_lookup_test.go
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -146,6 +146,63 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
require.Contains(t, err.Error(), "resume token")
}
+func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-snapshot-authority@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-snapshot-authority-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
+ SetOutTradeNo("sub2_resume_snapshot_authority").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-snapshot-authority").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("legacy-column-instance").
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "snapshot-instance",
+ "provider_key": payment.TypeEasyPay,
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "snapshot-instance",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
--
GitLab
From 64e401e22475336561966038968ffd0a9fba3a51 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:03:53 +0800
Subject: [PATCH 136/261] fix: tighten payment legacy fallback paths
---
.../internal/service/payment_fulfillment.go | 25 ++++++++++--
.../service/payment_fulfillment_test.go | 14 +++++++
.../service/payment_order_lifecycle.go | 38 +++++++++++++++++++
.../service/payment_order_lifecycle_test.go | 37 ++++++++++++++++++
4 files changed, 111 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 7bde03c8..423ed80f 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -25,9 +25,9 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
- // Fallback: try legacy format (sub2_N where N is DB ID)
- trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
- if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
+ // Fall back to the legacy "sub2_<db-id>" format only when the
+ // out_trade_no lookup returned a genuine not-found (never on other DB errors).
+ if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
@@ -35,6 +35,25 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
+func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) {
+ if !dbent.IsNotFound(lookupErr) {
+ return 0, false
+ }
+ orderID = strings.TrimSpace(orderID)
+ if !strings.HasPrefix(orderID, orderIDPrefix) {
+ return 0, false
+ }
+ trimmed := strings.TrimPrefix(orderID, orderIDPrefix)
+ if trimmed == "" || trimmed == orderID {
+ return 0, false
+ }
+ oid, err := strconv.ParseInt(trimmed, 10, 64)
+ if err != nil || oid <= 0 {
+ return 0, false
+ }
+ return oid, true
+}
+
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 8883d3b8..d70f8946 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -307,3 +307,17 @@ func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFi
})
assert.NoError(t, err)
}
+
+func TestParseLegacyPaymentOrderID(t *testing.T) {
+ t.Parallel()
+
+ oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{})
+ assert.True(t, ok)
+ assert.EqualValues(t, 42, oid)
+
+ _, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{})
+ assert.False(t, ok)
+
+ _, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
+ assert.False(t, ok)
+}
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index a192f599..ccab7c11 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -292,10 +292,48 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO
if inst != nil {
return s.createProviderFromInstance(ctx, inst)
}
+ if !paymentOrderAllowsRegistryFallback(o) {
+ return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID)
+ }
+ providerKey := paymentOrderFallbackProviderKey(s.registry, o)
+ if providerKey == "" {
+ return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey)
+ }
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
+func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool {
+ if order == nil {
+ return false
+ }
+ if psOrderProviderSnapshot(order) != nil {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" {
+ return false
+ }
+ return true
+}
+
+func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string {
+ if order == nil {
+ return ""
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType)))
+}
+
func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
if inst == nil {
return nil, fmt.Errorf("payment provider instance is missing")
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
index 3c6c65a5..39993a2f 100644
--- a/backend/internal/service/payment_order_lifecycle_test.go
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -323,6 +323,43 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor
require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
}
+func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ }))
+
+ instanceID := "12"
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderInstanceID: &instanceID,
+ }))
+
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "12",
+ },
+ }))
+}
+
+func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ OutTradeNo: "sub2_out_trade_no",
+ PaymentTradeNo: "wx-transaction-id",
+ }
+
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{}))
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{
+ key: payment.TypeWxpay,
+ }))
+}
+
func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
t.Helper()
--
GitLab
From ebd053c87ea752fa1845944c303ed09ea6b181be Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:07:40 +0800
Subject: [PATCH 137/261] docs: clarify openai scheduler flag semantics
---
frontend/src/views/admin/SettingsView.vue | 6 +++---
.../src/views/admin/__tests__/SettingsView.spec.ts | 10 ++++++++++
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 9fb8da41..e56afe5f 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1877,13 +1877,13 @@
- {{ localText('OpenAI 高级调度器', 'OpenAI advanced scheduler') }}
+ {{ localText('OpenAI 实验调度策略', 'OpenAI experimental scheduler policy') }}
{{
localText(
- '切换 OpenAI 侧新增的高级调度开关,供当前分支实验性调度逻辑使用。',
- 'Toggles the new OpenAI advanced scheduler flag for the experimental routing logic on this branch.'
+ '默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑,不代表上游 OpenAI 官方能力。',
+ 'Disabled by default. When enabled, this only changes the gateway\'s experimental account-selection policy for OpenAI traffic; it does not indicate an upstream OpenAI capability.'
)
}}
diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
index b6f8ab17..3541d994 100644
--- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts
+++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts
@@ -472,4 +472,14 @@ describe('admin SettingsView payment visible method controls', () => {
expect(showError).toHaveBeenCalled()
expect(String(showError.mock.calls.at(-1)?.[0] ?? '')).toContain('支付来源')
})
+
+ it('renders advanced scheduler copy as local experimental gateway policy', async () => {
+ const wrapper = mountView()
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('OpenAI 实验调度策略')
+ expect(wrapper.text()).toContain('默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑')
+ expect(wrapper.text()).not.toContain('OpenAI 高级调度器')
+ })
})
--
GitLab
From 267844ebe63261e4ebfe6bb485fb06fc32df4719 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:10:59 +0800
Subject: [PATCH 138/261] fix: fail closed for legacy refund provider
resolution
---
backend/internal/service/payment_refund.go | 44 ++++++-
.../internal/service/payment_refund_test.go | 117 ++++++++++++++++++
2 files changed, 158 insertions(+), 3 deletions(-)
create mode 100644 backend/internal/service/payment_refund_test.go
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index fbaeff99..6883056c 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -43,6 +43,37 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
+// getRefundOrderProviderInstance resolves the provider instance for refund paths.
+// Refunds must be pinned to an explicit historical binding, so legacy
+// "best-effort" provider guessing is intentionally not allowed here.
+func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || o == nil {
+ return nil, nil
+ }
+
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return nil, nil
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr)
+ }
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr)
+ }
+ return nil, err
+ }
+ return inst, nil
+}
+
func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
@@ -157,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
- inst, err := s.getOrderProviderInstance(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
@@ -177,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
- inst, instErr := s.getOrderProviderInstance(ctx, o)
+ inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
@@ -314,7 +345,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- return s.getOrderProvider(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, err
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID)
+ }
+ return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go
new file mode 100644
index 00000000..95104618
--- /dev/null
+++ b/backend/internal/service/payment_refund_test.go
@@ -0,0 +1,117 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ _, err = svc.validateRefundRequest(ctx, order.ID, user.ID)
+ require.Error(t, err)
+ require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err))
+}
+
+func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy-admin@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-admin-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-admin-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(188).
+ SetPayAmount(188).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_admin_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-admin-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false)
+ require.Nil(t, plan)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
+}
--
GitLab
From 0934f737d5c46a0451844c161b4c6e69bd2050d9 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:35:54 +0800
Subject: [PATCH 139/261] fix: snapshot merchant identity for alipay and
easypay
---
backend/internal/payment/provider/alipay.go | 44 ++++++++---
.../internal/payment/provider/alipay_test.go | 15 ++++
backend/internal/payment/provider/easypay.go | 28 ++++++-
.../payment/provider/easypay_sign_test.go | 15 ++++
.../internal/payment/provider/wxpay_test.go | 2 +-
backend/internal/payment/types.go | 6 ++
.../internal/service/payment_fulfillment.go | 42 +---------
.../service/payment_fulfillment_test.go | 34 ++++++++
backend/internal/service/payment_order.go | 10 +++
.../payment_order_provider_snapshot.go | 78 +++++++++++++++++++
.../payment_order_provider_snapshot_test.go | 34 ++++++++
backend/internal/service/payment_refund.go | 6 ++
.../internal/service/payment_refund_test.go | 69 ++++++++++++++++
13 files changed, 328 insertions(+), 55 deletions(-)
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index 4f87e5a7..0604883a 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -91,6 +91,17 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
+func (a *Alipay) MerchantIdentityMetadata() map[string]string {
+ if a == nil {
+ return nil
+ }
+ appID := strings.TrimSpace(a.config["appId"])
+ if appID == "" {
+ return nil
+ }
+ return map[string]string{"app_id": appID}
+}
+
// CreatePayment creates an Alipay payment page URL.
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
@@ -181,10 +192,11 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
}
return &payment.QueryOrderResponse{
- TradeNo: result.TradeNo,
- Status: status,
- Amount: amount,
- PaidAt: result.SendPayDate,
+ TradeNo: result.TradeNo,
+ Status: status,
+ Amount: amount,
+ PaidAt: result.SendPayDate,
+ Metadata: a.MerchantIdentityMetadata(),
}, nil
}
@@ -215,12 +227,21 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s
return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
}
+ metadata := a.MerchantIdentityMetadata()
+ if appID := strings.TrimSpace(notification.AppId); appID != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["app_id"] = appID
+ }
+
return &payment.PaymentNotification{
- TradeNo: notification.TradeNo,
- OrderID: notification.OutTradeNo,
- Amount: amount,
- Status: status,
- RawData: rawBody,
+ TradeNo: notification.TradeNo,
+ OrderID: notification.OutTradeNo,
+ Amount: amount,
+ Status: status,
+ RawData: rawBody,
+ Metadata: metadata,
}, nil
}
@@ -283,6 +304,7 @@ func isTradeNotExist(err error) bool {
// Ensure interface compliance.
var (
- _ payment.Provider = (*Alipay)(nil)
- _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.Provider = (*Alipay)(nil)
+ _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.MerchantIdentityProvider = (*Alipay)(nil)
)
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 6cc4246c..b25c05bd 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -243,3 +243,18 @@ func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
t.Fatalf("qr_code = %q, want empty", resp.QRCode)
}
}
+
+func TestAlipayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &Alipay{
+ config: map[string]string{
+ "appId": "2021001234567890",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["app_id"] != "2021001234567890" {
+ t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
+ }
+}
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index e33a567d..37bd38b2 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -59,6 +59,17 @@ func (e *EasyPay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
}
+func (e *EasyPay) MerchantIdentityMetadata() map[string]string {
+ if e == nil {
+ return nil
+ }
+ pid := strings.TrimSpace(e.config["pid"])
+ if pid == "" {
+ return nil
+ }
+ return map[string]string{"pid": pid}
+}
+
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
@@ -178,7 +189,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
status = payment.ProviderStatusPaid
}
amount, _ := strconv.ParseFloat(resp.Money, 64)
- return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: tradeNo,
+ Status: status,
+ Amount: amount,
+ Metadata: e.MerchantIdentityMetadata(),
+ }, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
@@ -203,9 +219,17 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
status = payment.ProviderStatusSuccess
}
amount, _ := strconv.ParseFloat(params["money"], 64)
+
+ metadata := e.MerchantIdentityMetadata()
+ if pid := strings.TrimSpace(params["pid"]); pid != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["pid"] = pid
+ }
return &payment.PaymentNotification{
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
- Amount: amount, Status: status, RawData: rawBody,
+ Amount: amount, Status: status, RawData: rawBody, Metadata: metadata,
}, nil
}
diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go
index 146a6fa1..8328d294 100644
--- a/backend/internal/payment/provider/easypay_sign_test.go
+++ b/backend/internal/payment/provider/easypay_sign_test.go
@@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
}
}
+
+func TestEasyPayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &EasyPay{
+ config: map[string]string{
+ "pid": "1001",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["pid"] != "1001" {
+ t.Fatalf("pid = %q, want %q", metadata["pid"], "1001")
+ }
+}
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index b3f4f648..0d79b1b0 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -110,7 +110,7 @@ func TestBuildWxpayTransactionMetadata(t *testing.T) {
Appid: strPtr("wx-app-id"),
Mchid: strPtr("mch-id"),
TradeState: strPtr(wxpayTradeStateSuccess),
- Amount: &payments.Amount{
+ Amount: &payments.TransactionAmount{
Currency: strPtr(wxpayCurrency),
},
}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
index 29abf82b..e7ac6727 100644
--- a/backend/internal/payment/types.go
+++ b/backend/internal/payment/types.go
@@ -214,3 +214,9 @@ type CancelableProvider interface {
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
+
+// MerchantIdentityProvider exposes the current non-sensitive merchant identity
+// derived from provider configuration for snapshot consistency checks.
+type MerchantIdentityProvider interface {
+ MerchantIdentityMetadata() map[string]string
+}
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 423ed80f..904960ee 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -96,47 +96,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
- if order == nil || len(metadata) == 0 || !strings.EqualFold(strings.TrimSpace(providerKey), payment.TypeWxpay) {
- return nil
- }
-
- snapshot := psOrderProviderSnapshot(order)
- if snapshot == nil {
- return nil
- }
-
- if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
- actual := strings.TrimSpace(metadata["appid"])
- if actual == "" {
- return fmt.Errorf("wxpay notification missing appid")
- }
- if !strings.EqualFold(expected, actual) {
- return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
- }
- }
- if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
- actual := strings.TrimSpace(metadata["mchid"])
- if actual == "" {
- return fmt.Errorf("wxpay notification missing mchid")
- }
- if !strings.EqualFold(expected, actual) {
- return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
- }
- }
- if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
- actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
- if actual == "" {
- return fmt.Errorf("wxpay notification missing currency")
- }
- if !strings.EqualFold(expected, actual) {
- return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
- }
- }
- if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
- return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
- }
-
- return nil
+ return validateProviderSnapshotMetadata(order, providerKey, metadata)
}
func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index d70f8946..6aed19f8 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -321,3 +321,37 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
_, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
assert.False(t, ok)
}
+
+func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_app_id": "alipay-app-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{
+ "app_id": "alipay-app-other",
+ })
+ assert.ErrorContains(t, err, "alipay app_id mismatch")
+}
+
+func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_id": "pid-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{
+ "pid": "pid-other",
+ })
+ assert.ErrorContains(t, err, "easypay pid mismatch")
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 6ee490a8..254af5fe 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -240,6 +240,16 @@ func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req Creat
}
snapshot["currency"] = "CNY"
}
+ if providerKey == payment.TypeAlipay {
+ if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ }
+ if providerKey == payment.TypeEasyPay {
+ if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ }
if len(snapshot) == 1 {
return nil
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
index 31a790c7..bb60f9e2 100644
--- a/backend/internal/service/payment_order_provider_snapshot.go
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -125,3 +125,81 @@ func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *
return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
}
+
+func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ if order == nil || len(metadata) == 0 {
+ return nil
+ }
+
+ snapshot := psOrderProviderSnapshot(order)
+ if snapshot == nil {
+ return nil
+ }
+
+ switch strings.TrimSpace(providerKey) {
+ case payment.TypeWxpay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["appid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing appid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["mchid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing mchid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
+ actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing currency")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
+ return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
+ }
+ case payment.TypeAlipay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["app_id"])
+ if actual == "" {
+ return fmt.Errorf("alipay app_id missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ case payment.TypeEasyPay:
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["pid"])
+ if actual == "" {
+ return fmt.Errorf("easypay pid missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ }
+
+ return nil
+}
+
+func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string {
+ if prov == nil {
+ return nil
+ }
+ reporter, ok := prov.(payment.MerchantIdentityProvider)
+ if !ok {
+ return nil
+ }
+ return reporter.MerchantIdentityMetadata()
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
index bc6666a8..efa013b5 100644
--- a/backend/internal/service/payment_order_provider_snapshot_test.go
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -130,6 +130,40 @@ func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t
require.Equal(t, "CNY", snapshot["currency"])
}
+func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "21",
+ ProviderKey: payment.TypeAlipay,
+ Config: map[string]string{
+ "appId": "alipay-app-21",
+ "privateKey": "secret",
+ },
+ PaymentMode: "redirect",
+ }, CreateOrderRequest{})
+
+ require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"])
+ require.NotContains(t, snapshot, "privateKey")
+}
+
+func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "66",
+ ProviderKey: payment.TypeEasyPay,
+ Config: map[string]string{
+ "pid": "easypay-merchant-66",
+ "pkey": "secret",
+ },
+ PaymentMode: "popup",
+ }, CreateOrderRequest{PaymentType: payment.TypeAlipay})
+
+ require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"])
+ require.NotContains(t, snapshot, "pkey")
+}
+
func valueOrEmpty(v *string) string {
if v == nil {
return ""
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index 6883056c..7521878c 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -333,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
+ if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
+ s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
+ "detail": err.Error(),
+ })
+ return err
+ }
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go
index 95104618..ca5b62cb 100644
--- a/backend/internal/service/payment_refund_test.go
+++ b/backend/internal/service/payment_refund_test.go
@@ -4,6 +4,7 @@ package service
import (
"context"
+ "strconv"
"testing"
"time"
@@ -115,3 +116,71 @@ func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
require.Error(t, err)
require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
}
+
+func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-snapshot-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-snapshot-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-mismatch-instance").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "runtime-alipay-app",
+ "privateKey": "runtime-private-key",
+ })).
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ instID := strconv.FormatInt(inst.ID, 10)
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER").
+ SetOutTradeNo("sub2_refund_snapshot_mismatch_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-refund-snapshot-mismatch").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instID).
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": instID,
+ "provider_key": payment.TypeAlipay,
+ "merchant_app_id": "expected-alipay-app",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ err = svc.gwRefund(ctx, &RefundPlan{
+ OrderID: order.ID,
+ Order: order,
+ RefundAmount: order.Amount,
+ GatewayAmount: order.Amount,
+ Reason: "snapshot mismatch",
+ })
+ require.ErrorContains(t, err, "alipay app_id mismatch")
+}
--
GitLab
From 12f1e19d688f1cbb892a152572eb793f5e7d097d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:36:19 +0800
Subject: [PATCH 140/261] fix: restore wechat oauth legacy callback
compatibility
---
.../src/views/auth/WechatCallbackView.vue | 107 ++++++++++++++---
.../auth/__tests__/WechatCallbackView.spec.ts | 111 ++++++++++++++++++
2 files changed, 203 insertions(+), 15 deletions(-)
diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue
index 36e3140c..a3efcf9a 100644
--- a/frontend/src/views/auth/WechatCallbackView.vue
+++ b/frontend/src/views/auth/WechatCallbackView.vue
@@ -300,6 +300,7 @@ import {
persistOAuthTokenContext,
resolveWeChatOAuthStartStrict,
type OAuthAdoptionDecision,
+ type OAuthTokenResponse,
type PendingOAuthExchangeResponse
} from '@/api/auth'
@@ -328,6 +329,7 @@ const pendingAccountAction = ref<'none' | 'create_account' | 'bind_login'>('none
const pendingAccountEmail = ref('')
const bindLoginEmail = ref('')
const bindLoginPassword = ref('')
+const legacyPendingOAuthToken = ref('')
const accountActionError = ref('')
const canReturnToCreateAccount = ref(false)
const needsTotpChallenge = ref(false)
@@ -354,12 +356,49 @@ type PendingWeChatCompletion = PendingOAuthExchangeResponse & {
user_email_masked?: string
}
+function persistPendingAuthSession(redirect?: string) {
+ authStore.setPendingAuthSession({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'wechat',
+ redirect: sanitizeRedirectPath(redirect || redirectTo.value)
+ })
+}
+
+function clearPendingAuthSession() {
+ authStore.clearPendingAuthSession()
+}
+
function parseFragmentParams(): URLSearchParams {
const raw = typeof window !== 'undefined' ? window.location.hash : ''
const hash = raw.startsWith('#') ? raw.slice(1) : raw
return new URLSearchParams(hash)
}
+function readLegacyFragmentLogin(params: URLSearchParams): OAuthTokenResponse | null {
+ const accessToken = params.get('access_token')?.trim() || ''
+ if (!accessToken) {
+ return null
+ }
+
+ const completion: OAuthTokenResponse = {
+ access_token: accessToken
+ }
+ const refreshToken = params.get('refresh_token')?.trim() || ''
+ if (refreshToken) {
+ completion.refresh_token = refreshToken
+ }
+ const expiresIn = Number.parseInt(params.get('expires_in')?.trim() || '', 10)
+ if (Number.isFinite(expiresIn) && expiresIn > 0) {
+ completion.expires_in = expiresIn
+ }
+ const tokenType = params.get('token_type')?.trim() || ''
+ if (tokenType) {
+ completion.token_type = tokenType
+ }
+ return completion
+}
+
function sanitizeRedirectPath(path: string | null | undefined): string {
if (!path) return '/dashboard'
if (!path.startsWith('/')) return '/dashboard'
@@ -672,6 +711,7 @@ function isCreateAccountRecoveryError(error: unknown): boolean {
async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) {
if (getOAuthCompletionKind(completion) === 'bind') {
const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile')
+ clearPendingAuthSession()
appStore.showSuccess(bindSuccessMessage)
await router.replace(bindRedirect)
return
@@ -689,16 +729,19 @@ async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redi
async function finalizePendingAccountResponse(completion: PendingWeChatCompletion) {
applyAdoptionSuggestionState(completion)
+ const redirect = sanitizeRedirectPath(completion.redirect || redirectTo.value)
if (completion.error === 'invitation_required') {
pendingAccountAction.value = 'none'
needsInvitation.value = true
needsAdoptionConfirmation.value = false
isProcessing.value = false
+ persistPendingAuthSession(redirect)
return
}
if (applyTotpChallenge(completion)) {
+ persistPendingAuthSession(redirect)
return
}
@@ -707,10 +750,10 @@ async function finalizePendingAccountResponse(completion: PendingWeChatCompletio
needsInvitation.value = false
needsAdoptionConfirmation.value = false
isProcessing.value = false
+ persistPendingAuthSession(redirect)
return
}
- const redirect = sanitizeRedirectPath(completion.redirect || redirectTo.value)
await finalizeCompletion(completion, redirect)
}
@@ -720,10 +763,18 @@ async function handleSubmitInvitation() {
isSubmitting.value = true
try {
- const tokenData = await completeWeChatOAuthRegistration(
- invitationCode.value.trim(),
- currentAdoptionDecision()
- )
+ const tokenData = legacyPendingOAuthToken.value
+ ? (
+ await apiClient.post('/auth/oauth/wechat/complete-registration', {
+ pending_oauth_token: legacyPendingOAuthToken.value,
+ invitation_code: invitationCode.value.trim(),
+ ...serializeAdoptionDecision(currentAdoptionDecision())
+ })
+ ).data
+ : await completeWeChatOAuthRegistration(
+ invitationCode.value.trim(),
+ currentAdoptionDecision()
+ )
persistOAuthTokenContext(tokenData)
await authStore.setToken(tokenData.access_token)
appStore.showSuccess(t('auth.loginSuccess'))
@@ -864,48 +915,74 @@ onMounted(async () => {
}
const params = parseFragmentParams()
+ const legacyLogin = readLegacyFragmentLogin(params)
+ const legacyPendingToken = params.get('pending_oauth_token')?.trim() || ''
const error = params.get('error')
const errorDesc = params.get('error_description') || params.get('error_message') || ''
-
- if (error) {
- errorMessage.value = errorDesc || error
- appStore.showError(errorMessage.value)
- isProcessing.value = false
- return
- }
+ const redirect = sanitizeRedirectPath(
+ params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard'
+ )
try {
+ if (legacyLogin) {
+ persistOAuthTokenContext(legacyLogin)
+ await authStore.setToken(legacyLogin.access_token)
+ appStore.showSuccess(t('auth.loginSuccess'))
+ await router.replace(redirect)
+ return
+ }
+
+ if (error === 'invitation_required' && legacyPendingToken) {
+ legacyPendingOAuthToken.value = legacyPendingToken
+ redirectTo.value = redirect
+ needsInvitation.value = true
+ isProcessing.value = false
+ return
+ }
+
+ if (error) {
+ errorMessage.value = errorDesc || error
+ appStore.showError(errorMessage.value)
+ isProcessing.value = false
+ return
+ }
+
const completion = await exchangePendingOAuthCompletion() as PendingWeChatCompletion
- const redirect = sanitizeRedirectPath(
+ const completionRedirect = sanitizeRedirectPath(
completion.redirect || (route.query.redirect as string | undefined) || '/dashboard'
)
applyAdoptionSuggestionState(completion)
- redirectTo.value = redirect
+ redirectTo.value = completionRedirect
if (completion.error === 'invitation_required') {
needsInvitation.value = true
isProcessing.value = false
+ persistPendingAuthSession(completionRedirect)
return
}
if (applyTotpChallenge(completion)) {
+ persistPendingAuthSession(completionRedirect)
return
}
applyPendingAccountAction(completion)
if (pendingAccountAction.value !== 'none') {
isProcessing.value = false
+ persistPendingAuthSession(completionRedirect)
return
}
if (adoptionRequired.value && hasSuggestedProfile(completion)) {
needsAdoptionConfirmation.value = true
isProcessing.value = false
+ persistPendingAuthSession(completionRedirect)
return
}
- await finalizeCompletion(completion, redirect)
+ await finalizeCompletion(completion, completionRedirect)
} catch (e: unknown) {
+ clearPendingAuthSession()
errorMessage.value = getRequestErrorMessage(e, t('auth.loginFailed'))
appStore.showError(errorMessage.value)
isProcessing.value = false
diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
index e02060f6..98a0268d 100644
--- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
+++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts
@@ -14,6 +14,8 @@ const {
getAuthTokenMock,
replaceMock,
setTokenMock,
+ setPendingAuthSessionMock,
+ clearPendingAuthSessionMock,
showSuccessMock,
showErrorMock,
fetchPublicSettingsMock,
@@ -32,6 +34,8 @@ const {
getAuthTokenMock: vi.fn(),
replaceMock: vi.fn(),
setTokenMock: vi.fn(),
+ setPendingAuthSessionMock: vi.fn(),
+ clearPendingAuthSessionMock: vi.fn(),
showSuccessMock: vi.fn(),
showErrorMock: vi.fn(),
fetchPublicSettingsMock: vi.fn(),
@@ -111,6 +115,8 @@ vi.mock('vue-i18n', () => ({
vi.mock('@/stores', () => ({
useAuthStore: () => ({
setToken: setTokenMock,
+ setPendingAuthSession: setPendingAuthSessionMock,
+ clearPendingAuthSession: clearPendingAuthSessionMock,
}),
useAppStore: () => ({
...appStoreState,
@@ -152,6 +158,8 @@ describe('WechatCallbackView', () => {
getPublicSettingsMock.mockReset()
replaceMock.mockReset()
setTokenMock.mockReset()
+ setPendingAuthSessionMock.mockReset()
+ clearPendingAuthSessionMock.mockReset()
showSuccessMock.mockReset()
showErrorMock.mockReset()
prepareOAuthBindAccessTokenCookieMock.mockReset()
@@ -269,6 +277,81 @@ describe('WechatCallbackView', () => {
expect(locationState.current.href).toContain('mode=open')
})
+ it('accepts the legacy fragment token success callback without pending-session exchange', async () => {
+ locationState.current.hash =
+ '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard'
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ setTokenMock.mockResolvedValue({})
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+          AuthLayout: { template: '<div><slot /></div>' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled()
+ expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token')
+ expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token')
+ expect(localStorage.getItem('token_expires_at')).not.toBeNull()
+ expect(showSuccessMock).toHaveBeenCalledWith('Login success')
+ expect(replaceMock).toHaveBeenCalledWith('/legacy-dashboard')
+ })
+
+ it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => {
+ locationState.current.hash =
+ '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite'
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ apiClientPostMock.mockResolvedValue({
+ data: {
+ access_token: 'legacy-access-token',
+ refresh_token: 'legacy-refresh-token',
+ expires_in: 3600,
+ token_type: 'Bearer',
+ },
+ })
+ setTokenMock.mockResolvedValue({})
+
+ const wrapper = mount(WechatCallbackView, {
+ global: {
+ stubs: {
+          AuthLayout: { template: '<div><slot /></div>' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled()
+ await wrapper.find('input[type="text"]').setValue('invite-code')
+ await wrapper.find('button').trigger('click')
+ await flushPromises()
+
+ expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ pending_oauth_token: 'legacy-pending-token',
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: true,
+ })
+ expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token')
+ expect(replaceMock).toHaveBeenCalledWith('/legacy-invite')
+ })
+
it('does not send adoption decisions during the initial exchange', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
access_token: 'access-token',
@@ -382,6 +465,7 @@ describe('WechatCallbackView', () => {
adoptAvatar: true,
})
expect(setTokenMock).not.toHaveBeenCalled()
+ expect(clearPendingAuthSessionMock).toHaveBeenCalledTimes(1)
expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess')
expect(replaceMock).toHaveBeenCalledWith('/profile/connections')
})
@@ -548,6 +632,33 @@ describe('WechatCallbackView', () => {
expect(replaceMock).toHaveBeenCalledWith('/welcome')
})
+ it('persists a pending auth session when the oauth flow still needs account creation', async () => {
+ exchangePendingOAuthCompletionMock.mockResolvedValue({
+ error: 'email_required',
+ redirect: '/welcome',
+ })
+
+ mount(WechatCallbackView, {
+ global: {
+ stubs: {
+          AuthLayout: { template: '<div><slot /></div>' },
+ Icon: true,
+ RouterLink: { template: ' ' },
+ transition: false,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
+ token: '',
+ token_field: 'pending_oauth_token',
+ provider: 'wechat',
+ redirect: '/welcome',
+ })
+ })
+
it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => {
exchangePendingOAuthCompletionMock.mockResolvedValue({
error: 'email_required',
--
GitLab
From 65efef1eee696b295b11e550d424e6d76a944f76 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:47:15 +0800
Subject: [PATCH 141/261] feat: support replacing bound primary email
---
backend/internal/handler/user_handler.go | 2 +-
backend/internal/handler/user_handler_test.go | 53 +++++++
.../internal/service/auth_email_binding.go | 31 ++--
.../service/auth_service_email_bind_test.go | 142 ++++++++++++++++++
.../ProfileIdentityBindingsSection.vue | 28 +++-
.../ProfileIdentityBindingsSection.spec.ts | 68 +++++++++
frontend/src/i18n/locales/en.ts | 3 +
frontend/src/i18n/locales/zh.ts | 3 +
8 files changed, 313 insertions(+), 17 deletions(-)
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index a6a7be9a..497a23c4 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -167,7 +167,7 @@ type StartIdentityBindingRequest struct {
type BindEmailIdentityRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code" binding:"required"`
- Password string `json:"password" binding:"required,min=6"`
+ Password string `json:"password" binding:"required"`
}
type SendEmailBindingCodeRequest struct {
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
index 72b28293..24f715d4 100644
--- a/backend/internal/handler/user_handler_test.go
+++ b/backend/internal/handler/user_handler_test.go
@@ -422,6 +422,59 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
require.True(t, resp.Data.EmailBound)
}
+func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ user := &service.User{
+ ID: 11,
+ Email: "current@example.com",
+ Username: "bound-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, user.SetPassword("current-password"))
+
+ repo := &userHandlerRepoStub{user: user}
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "PASSWORD_INCORRECT", resp.Reason)
+ require.Equal(t, "current password is incorrect", resp.Message)
+ require.Equal(t, "current@example.com", repo.user.Email)
+}
+
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
index 58f8e647..b060ab76 100644
--- a/backend/internal/service/auth_email_binding.go
+++ b/backend/internal/service/auth_email_binding.go
@@ -13,7 +13,8 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
-// BindEmailIdentity verifies and binds a local email/password identity to the current user.
+// BindEmailIdentity verifies and binds a local email/password identity to the
+// current user, or replaces the existing bound primary email.
func (s *AuthService) BindEmailIdentity(
ctx context.Context,
userID int64,
@@ -43,6 +44,13 @@ func (s *AuthService) BindEmailIdentity(
if err != nil {
return nil, err
}
+ firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ if firstRealEmailBind && len(password) < 6 {
+ return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters")
+ }
+ if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) {
+ return nil, ErrPasswordIncorrect
+ }
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
@@ -57,9 +65,8 @@ func (s *AuthService) BindEmailIdentity(
return nil, fmt.Errorf("hash password: %w", err)
}
- firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
- if firstRealEmailBind && s.entClient != nil {
- if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
+ if s.entClient != nil {
+ if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
return nil, err
}
return currentUser, nil
@@ -137,14 +144,15 @@ func hasBindableEmailIdentitySubject(email string) bool {
return normalized != "" && !isReservedEmail(normalized)
}
-func (s *AuthService) bindEmailIdentityWithDefaultsTx(
+func (s *AuthService) updateBoundEmailIdentityTx(
ctx context.Context,
currentUser *User,
email string,
hashedPassword string,
+ applyFirstBindDefaults bool,
) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
- return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
+ return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults)
}
tx, err := s.entClient.Tx(ctx)
@@ -154,7 +162,7 @@ func (s *AuthService) bindEmailIdentityWithDefaultsTx(
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
- if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
+ if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil {
return err
}
if err := tx.Commit(); err != nil {
@@ -163,12 +171,13 @@ func (s *AuthService) bindEmailIdentityWithDefaultsTx(
return nil
}
-func (s *AuthService) bindEmailIdentityWithDefaults(
+func (s *AuthService) updateBoundEmailIdentityWithClient(
ctx context.Context,
client *dbent.Client,
currentUser *User,
email string,
hashedPassword string,
+ applyFirstBindDefaults bool,
) error {
if client == nil || currentUser == nil || currentUser.ID <= 0 {
return ErrServiceUnavailable
@@ -192,8 +201,10 @@ func (s *AuthService) bindEmailIdentityWithDefaults(
return ErrServiceUnavailable
}
- if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
- return fmt.Errorf("apply email first bind defaults: %w", err)
+ if applyFirstBindDefaults {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+ return fmt.Errorf("apply email first bind defaults: %w", err)
+ }
}
updatedUser, err := client.User.Get(ctx, currentUser.ID)
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
index fd5f499b..d32a4a40 100644
--- a/backend/internal/service/auth_service_email_bind_test.go
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -285,6 +285,148 @@ func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
require.Nil(t, updatedUser)
}
+func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(7.5).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "new@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.Equal(t, 7.5, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, newIdentityCount)
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, oldIdentityCount)
+
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password")
+ require.ErrorIs(t, err, service.ErrPasswordIncorrect)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "current@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, oldIdentityCount)
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, newIdentityCount)
+}
+
type emailBindSettingRepoStub struct {
values map[string]string
}
diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
index 653b4e33..ee582a60 100644
--- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
+++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue
@@ -34,7 +34,7 @@
@@ -160,7 +160,7 @@ watch(
() => props.user,
(user) => {
localUser.value = null
- if (!user || getBindingStatusForUser(user, 'email')) {
+ if (!user) {
return
}
if (typeof user.email === 'string' && !user.email.endsWith('.invalid')) {
@@ -171,6 +171,17 @@ watch(
)
const currentUser = computed(() => localUser.value ?? props.user)
+const emailBound = computed(() => getBindingStatus('email'))
+const emailPasswordPlaceholder = computed(() =>
+ emailBound.value
+ ? t('profile.authBindings.replaceEmailPasswordPlaceholder')
+ : t('profile.authBindings.passwordPlaceholder')
+)
+const emailSubmitActionLabel = computed(() =>
+ emailBound.value
+ ? t('profile.authBindings.confirmEmailReplaceAction')
+ : t('profile.authBindings.confirmEmailBindAction')
+)
const wechatOAuthSettings = computed
(() => {
if (hasExplicitWeChatOAuthCapabilities(appStore.cachedPublicSettings)) {
@@ -286,7 +297,7 @@ function validateEmailBindingForm(requireCode: boolean): boolean {
appStore.showError(t('auth.passwordRequired'))
return false
}
- if (requireCode && emailBindingForm.password.length < 6) {
+ if (requireCode && !emailBound.value && emailBindingForm.password.length < 6) {
appStore.showError(t('auth.passwordMinLength'))
return false
}
@@ -321,10 +332,15 @@ async function bindEmail(): Promise {
verify_code: emailBindingForm.verifyCode,
password: emailBindingForm.password,
})
+ const replacingBoundEmail = emailBound.value
applyUpdatedUser(user)
emailBindingForm.verifyCode = ''
emailBindingForm.password = ''
- appStore.showSuccess(t('profile.authBindings.bindSuccess'))
+ appStore.showSuccess(
+ replacingBoundEmail
+ ? t('profile.authBindings.replaceSuccess')
+ : t('profile.authBindings.bindSuccess')
+ )
} catch (error) {
appStore.showError((error as { message?: string }).message || t('common.tryAgain'))
} finally {
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
index c07acf18..8821cdc5 100644
--- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
+++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts
@@ -51,10 +51,14 @@ vi.mock('vue-i18n', async (importOriginal) => {
if (key === 'profile.authBindings.emailPlaceholder') return 'Email address'
if (key === 'profile.authBindings.codePlaceholder') return 'Verification code'
if (key === 'profile.authBindings.passwordPlaceholder') return 'Set password'
+ if (key === 'profile.authBindings.replaceEmailPasswordPlaceholder')
+ return 'Current password'
if (key === 'profile.authBindings.sendCodeAction') return 'Send code'
if (key === 'profile.authBindings.confirmEmailBindAction') return 'Bind email'
+ if (key === 'profile.authBindings.confirmEmailReplaceAction') return 'Replace primary email'
if (key === 'profile.authBindings.codeSentTo') return `Code sent to ${params?.email || ''}`.trim()
if (key === 'profile.authBindings.bindSuccess') return 'Bind success'
+ if (key === 'profile.authBindings.replaceSuccess') return 'Primary email updated'
return key
},
}),
@@ -324,4 +328,68 @@ describe('ProfileIdentityBindingsSection', () => {
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound')
expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
})
+
+ it('keeps the email form available for replacing a bound primary email', async () => {
+ userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined)
+ userApiMocks.bindEmailIdentity.mockResolvedValue(
+ createUser({
+ email: 'new@example.com',
+ email_bound: true,
+ auth_bindings: {
+ email: { bound: true },
+ },
+ })
+ )
+
+ const appStore = useAppStore()
+ const authStore = useAuthStore()
+ authStore.user = createUser({
+ email: 'current@example.com',
+ email_bound: true,
+ auth_bindings: {
+ email: { bound: true },
+ },
+ })
+ const showSuccessSpy = vi.spyOn(appStore, 'showSuccess')
+
+ const wrapper = mount(ProfileIdentityBindingsSection, {
+ global: {
+ plugins: [pinia],
+ },
+ props: {
+ user: authStore.user,
+ linuxdoEnabled: false,
+ oidcEnabled: false,
+ wechatEnabled: false,
+ },
+ })
+
+ expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
+ expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
+ expect(wrapper.get('[data-testid="profile-binding-email-submit"]').text()).toBe(
+ 'Replace primary email'
+ )
+ expect(
+ (wrapper.get('[data-testid="profile-binding-email-password-input"]').element as HTMLInputElement)
+ .placeholder
+ ).toBe('Current password')
+
+ await wrapper.get('[data-testid="profile-binding-email-input"]').setValue('new@example.com')
+ await wrapper.get('[data-testid="profile-binding-email-send-code"]').trigger('click')
+ expect(userApiMocks.sendEmailBindingCode).toHaveBeenCalledWith('new@example.com')
+
+ await wrapper.get('[data-testid="profile-binding-email-code-input"]').setValue('123456')
+ await wrapper.get('[data-testid="profile-binding-email-password-input"]').setValue(
+ 'current-password'
+ )
+ await wrapper.get('[data-testid="profile-binding-email-submit"]').trigger('click')
+
+ expect(userApiMocks.bindEmailIdentity).toHaveBeenCalledWith({
+ email: 'new@example.com',
+ verify_code: '123456',
+ password: 'current-password',
+ })
+ expect(authStore.user?.email).toBe('new@example.com')
+ expect(showSuccessSpy).toHaveBeenCalledWith('Primary email updated')
+ })
})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 2b41a3c3..345770a8 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -967,9 +967,12 @@ export default {
emailPlaceholder: 'Enter email address',
codePlaceholder: 'Enter verification code',
passwordPlaceholder: 'Set a login password',
+ replaceEmailPasswordPlaceholder: 'Enter current password',
sendCodeAction: 'Send code',
confirmEmailBindAction: 'Bind email',
+ confirmEmailReplaceAction: 'Replace primary email',
codeSentTo: 'Code sent to {email}',
+ replaceSuccess: 'Primary email updated',
status: {
bound: 'Bound',
notBound: 'Not bound',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index b60a69d6..6493ffe8 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -971,9 +971,12 @@ export default {
emailPlaceholder: '输入邮箱地址',
codePlaceholder: '输入验证码',
passwordPlaceholder: '设置登录密码',
+ replaceEmailPasswordPlaceholder: '输入当前密码',
sendCodeAction: '发送验证码',
confirmEmailBindAction: '绑定邮箱',
+ confirmEmailReplaceAction: '更换主邮箱',
codeSentTo: '验证码已发送到 {email}',
+ replaceSuccess: '主邮箱已更新',
status: {
bound: '已绑定',
notBound: '未绑定',
--
GitLab
From ace082066a14824462cb45386f43626ed7db9547 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:50:55 +0800
Subject: [PATCH 142/261] fix: honor ws transport when scheduler is disabled
---
.../service/openai_account_scheduler.go | 59 ++++++++++--
.../service/openai_account_scheduler_test.go | 92 +++++++++++++++++++
2 files changed, 143 insertions(+), 8 deletions(-)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 09e60220..38b92b47 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -767,14 +767,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
- // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
- if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
- return true
- }
- if s == nil || s.service == nil || account == nil {
+ if s == nil || s.service == nil {
return false
}
- return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+ return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
@@ -899,9 +895,35 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
- selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
- return selection, decision, err
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
+ return selection, decision, err
+ }
+
+ effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
+ for {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection == nil || selection.Account == nil {
+ return selection, decision, nil
+ }
+ if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
+ return selection, decision, nil
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if effectiveExcludedIDs == nil {
+ effectiveExcludedIDs = make(map[int64]struct{})
+ }
+ if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
+ return nil, decision, ErrNoAvailableAccounts
+ }
+ effectiveExcludedIDs[selection.Account.ID] = struct{}{}
+ }
}
var stickyAccountID int64
@@ -922,6 +944,27 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
})
}
+func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
+ if len(excludedIDs) == 0 {
+ return nil
+ }
+ cloned := make(map[int64]struct{}, len(excludedIDs))
+ for id := range excludedIDs {
+ cloned[id] = struct{}{}
+ }
+ return cloned
+}
+
+func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
+ if s == nil || account == nil {
+ return false
+ }
+ return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+}
+
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index a54f2614..b02370cb 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -298,6 +298,98 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLega
require.False(t, decision.StickyPreviousHit)
}
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10108)
+ accounts := []Account{
+ {
+ ID: 36011,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 36012,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36012), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10109)
+ accounts := []Account{
+ {
+ ID: 36021,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.ErrorContains(t, err, "no available OpenAI accounts")
+ require.Nil(t, selection)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
--
GitLab
From 0fcddce69edd1fca78dbd5d1b1099d8a6fe4c3b8 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:53:12 +0800
Subject: [PATCH 143/261] fix: reject http responses continuation ids
---
.../handler/openai_gateway_handler.go | 9 ++-
.../handler/openai_gateway_handler_test.go | 58 +++++++++++++++++++
2 files changed, 65 insertions(+), 2 deletions(-)
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 5319b55d..43999a01 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "previous_response_id_requires_wsv2"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2")
+ return
}
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -856,7 +861,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
if validation.HasItemReferenceForAllCallIDs {
@@ -866,7 +871,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index d299fb81..8ecee59a 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.Contains(t, w.Body.String(), "previous_response_id")
+}
+
+func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.NotContains(t, w.Body.String(), "reuse previous_response_id")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
--
GitLab
From 62ff2d803f172defdc6648599edd7d3379539b35 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 13:56:02 +0800
Subject: [PATCH 144/261] fix: normalize chat completions service tier
---
.../openai_gateway_chat_completions.go | 42 +++++++++++++++++-
.../openai_gateway_chat_completions_test.go | 44 +++++++++++++++++++
2 files changed, 85 insertions(+), 1 deletion(-)
create mode 100644 backend/internal/service/openai_gateway_chat_completions_test.go
diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index ac7d28a7..663066a3 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
responsesBody = stripped
}
}
+ responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
+ if err != nil {
+ return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
+ }
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
- ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
+ ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
+ normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return result, handleErr
}
+func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
+ if req == nil {
+ return
+ }
+ req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
+}
+
+func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
+ if len(body) == 0 {
+ return body, "", nil
+ }
+ rawServiceTier := gjson.GetBytes(body, "service_tier").String()
+ if rawServiceTier == "" {
+ return body, "", nil
+ }
+ normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
+ if normalizedServiceTier == "" {
+ trimmed, err := sjson.DeleteBytes(body, "service_tier")
+ return trimmed, "", err
+ }
+ if normalizedServiceTier == rawServiceTier {
+ return body, normalizedServiceTier, nil
+ }
+ trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
+ return trimmed, normalizedServiceTier, err
+}
+
+func normalizedOpenAIServiceTierValue(raw string) string {
+ normalized := normalizeOpenAIServiceTier(raw)
+ if normalized == nil {
+ return ""
+ }
+ return *normalized
+}
+
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
new file mode 100644
index 00000000..a00fb71c
--- /dev/null
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -0,0 +1,44 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
+ t.Parallel()
+
+ req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "priority", req.ServiceTier)
+
+ req.ServiceTier = "flex"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "flex", req.ServiceTier)
+
+ req.ServiceTier = "default"
+ normalizeResponsesRequestServiceTier(req)
+ require.Empty(t, req.ServiceTier)
+}
+
+func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
+ t.Parallel()
+
+ body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
+ require.NoError(t, err)
+ require.Equal(t, "priority", tier)
+ require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
+ require.NoError(t, err)
+ require.Equal(t, "flex", tier)
+ require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
+ require.NoError(t, err)
+ require.Empty(t, tier)
+ require.False(t, gjson.GetBytes(body, "service_tier").Exists())
+}
--
GitLab
From 147ed42ad355ec4a100c0b29ba65ed3e1aeef233 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:10:30 +0800
Subject: [PATCH 145/261] fix: restrict payment return urls to internal result
page
---
backend/internal/service/payment_order.go | 2 +-
.../service/payment_resume_service.go | 46 +++++++++++++++++--
.../service/payment_resume_service_test.go | 24 ++++++++--
3 files changed, 63 insertions(+), 9 deletions(-)
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 254af5fe..354f3cd1 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -350,7 +350,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
- canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL)
+ canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost)
if err != nil {
return nil, err
}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
index 486aaac0..1806f5da 100644
--- a/backend/internal/service/payment_resume_service.go
+++ b/backend/internal/service/payment_resume_service.go
@@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
+ "net"
"net/url"
"strconv"
"strings"
@@ -16,6 +17,8 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
+const paymentResultReturnPath = "/payment/result"
+
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
@@ -215,7 +218,7 @@ func visibleMethodSourceSettingKey(method string) string {
}
}
-func CanonicalizeReturnURL(raw string) (string, error) {
+func CanonicalizeReturnURL(raw string, srcHost string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
@@ -231,19 +234,29 @@ func CanonicalizeReturnURL(raw string) (string, error) {
if parsed.Path == "" {
parsed.Path = "/"
}
+ if parsed.Path != paymentResultReturnPath {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
+ }
+ if !sameOriginHost(parsed.Host, srcHost) {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site")
+ }
return parsed.String(), nil
}
func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (string, error) {
- canonical, err := CanonicalizeReturnURL(base)
- if err != nil || canonical == "" {
- return canonical, err
+ canonical := strings.TrimSpace(base)
+ if canonical == "" {
+ return "", nil
}
parsed, err := url.Parse(canonical)
if err != nil {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL")
}
+ if !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL")
+ }
+ parsed.Fragment = ""
query := parsed.Query()
if orderID > 0 {
@@ -258,6 +271,31 @@ func buildPaymentReturnURL(base string, orderID int64, resumeToken string) (stri
return parsed.String(), nil
}
+func sameOriginHost(returnURLHost string, requestHost string) bool {
+ returnHost := strings.TrimSpace(returnURLHost)
+ reqHost := strings.TrimSpace(requestHost)
+ if returnHost == "" || reqHost == "" {
+ return false
+ }
+ if strings.EqualFold(returnHost, reqHost) {
+ return true
+ }
+
+ returnName, returnPort := splitHostPortDefault(returnHost)
+ reqName, reqPort := splitHostPortDefault(reqHost)
+ if returnName == "" || reqName == "" {
+ return false
+ }
+ return strings.EqualFold(returnName, reqName) && returnPort == reqPort
+}
+
+func splitHostPortDefault(raw string) (string, string) {
+ if host, port, err := net.SplitHostPort(raw); err == nil {
+ return host, port
+ }
+ return raw, ""
+}
+
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if err := s.ensureSigningKey(); err != nil {
return "", err
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
index 12d67be2..7fa8dca1 100644
--- a/backend/internal/service/payment_resume_service_test.go
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -64,23 +64,39 @@ func TestNormalizePaymentSource(t *testing.T) {
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
- got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
+ got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
- if got != "https://example.com/pay/result?b=2" {
- t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
+ if got != "https://example.com/payment/result?b=2" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
- if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
+ if _, err := CanonicalizeReturnURL("/payment/result", "example.com"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
+func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com"); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject external hosts")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com"); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
+ }
+}
+
func TestBuildPaymentReturnURL(t *testing.T) {
t.Parallel()
--
GitLab
From 4a3652ec09d54fa0973ac93d7b3b501a550098aa Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 14:10:53 +0800
Subject: [PATCH 146/261] refactor(channels): normalize at cache fill and
eliminate frontend as-cast
- channel.go: convert normalizeBillingModelSource into a (*Channel) method for entity cohesion
- channel_service.go: normalize in populateChannelCache so every cache-backed reader (gateway, billing, future endpoints) sees the default; drop the duplicate fallback inside resolveMapping
- table: tighten Row with status?: ChannelStatus / billing_model_source?: BillingModelSource, remove the [key: string]: unknown index signature
- admin view: drop the `as ChannelStatus` / `as BillingModelSource` assertions and add statusStyleOf / billingSourceLabelOf helpers with runtime fallback so unseen values render as "-" instead of crashing
---
backend/internal/service/channel.go | 12 +++++++++
backend/internal/service/channel_available.go | 2 +-
backend/internal/service/channel_service.go | 27 ++++++++-----------
.../channels/AvailableChannelsTable.vue | 6 ++++-
.../src/views/admin/AvailableChannelsView.vue | 21 ++++++++++++---
5 files changed, 46 insertions(+), 22 deletions(-)
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index dcb68dc5..fa1a87c1 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -111,6 +111,18 @@ func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
+// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
+// 作为 *Channel 的实体方法集中管理默认值,service 层只需在 Channel 进入内存
+// (缓存装填、repo 读出)时调用一次,下游读路径就无需重复兜底。
+func (c *Channel) normalizeBillingModelSource() {
+ if c == nil {
+ return
+ }
+ if c.BillingModelSource == "" {
+ c.BillingModelSource = BillingModelSourceChannelMapped
+ }
+}
+
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go
index a162d81d..7f6d1e85 100644
--- a/backend/internal/service/channel_available.go
+++ b/backend/internal/service/channel_available.go
@@ -66,7 +66,7 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
}
sort.SliceStable(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
- normalizeBillingModelSource(ch)
+ ch.normalizeBillingModelSource()
out = append(out, AvailableChannel{
ID: ch.ID,
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index 4f22e205..51984400 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -301,6 +301,9 @@ func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[i
}
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
+// 装填时对每个 Channel 统一归一化 BillingModelSource,让缓存命中的所有下游
+// (gateway routing / billing / 未来任何 cache-backed 读路径)都拿到已归一化的实体,
+// 避免"每个出口各自记得 normalize"反模式。
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
@@ -308,6 +311,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
cache.loadedAt = time.Now()
for i := range channels {
+ channels[i].normalizeBillingModelSource()
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
@@ -518,14 +522,13 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
// resolveMapping 基于已查找的渠道信息解析模型映射。
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
+ // lk.channel 来自已装填的缓存,BillingModelSource 已在 populateChannelCache 阶段归一化,
+ // 这里无需重复兜底。
result := ChannelMappingResult{
MappedModel: model,
ChannelID: lk.channel.ID,
BillingModelSource: lk.channel.BillingModelSource,
}
- if result.BillingModelSource == "" {
- result.BillingModelSource = BillingModelSourceChannelMapped
- }
modelLower := strings.ToLower(model)
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
@@ -686,7 +689,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
AccountStatsPricingRules: input.AccountStatsPricingRules,
}
- normalizeBillingModelSource(channel)
+ channel.normalizeBillingModelSource()
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
@@ -706,7 +709,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if err != nil {
return nil, err
}
- normalizeBillingModelSource(created)
+ created.normalizeBillingModelSource()
return created, nil
}
@@ -717,18 +720,10 @@ func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error
if err != nil {
return nil, err
}
- normalizeBillingModelSource(ch)
+ ch.normalizeBillingModelSource()
return ch, nil
}
-// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
-// 统一在 service 层完成,避免 handler 响应层重复兜底。
-func normalizeBillingModelSource(ch *Channel) {
- if ch != nil && ch.BillingModelSource == "" {
- ch.BillingModelSource = BillingModelSourceChannelMapped
- }
-}
-
// Update 更新渠道
func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) {
channel, err := s.repo.GetByID(ctx, id)
@@ -762,7 +757,7 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if err != nil {
return nil, err
}
- normalizeBillingModelSource(updated)
+ updated.normalizeBillingModelSource()
return updated, nil
}
@@ -886,7 +881,7 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
return nil, nil, err
}
for i := range channels {
- normalizeBillingModelSource(&channels[i])
+ channels[i].normalizeBillingModelSource()
}
return channels, res, nil
}
diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue
index 0bd19518..96aa82a9 100644
--- a/frontend/src/components/channels/AvailableChannelsTable.vue
+++ b/frontend/src/components/channels/AvailableChannelsTable.vue
@@ -62,6 +62,7 @@ import DataTable from '@/components/common/DataTable.vue'
import Icon from '@/components/icons/Icon.vue'
import SupportedModelChip from './SupportedModelChip.vue'
import type { UserSupportedModel } from '@/api/channels'
+import type { ChannelStatus, BillingModelSource } from '@/constants/channel'
interface GroupRef {
id: number
@@ -75,7 +76,10 @@ interface Row {
groups: GroupRef[]
// 复用 user 侧最小 DTO;admin 侧 SupportedModel 结构上是其超集,可直接传入。
supported_models: UserSupportedModel[]
- [key: string]: unknown
+ // admin 独有字段:用精确类型代替 `unknown`,让消费端无需 `as` 断言,
+ // 也能在后端新增 union 成员时让前端 Record 查表立刻出空而非崩溃。
+ status?: ChannelStatus
+ billing_model_source?: BillingModelSource
}
interface Column {
diff --git a/frontend/src/views/admin/AvailableChannelsView.vue b/frontend/src/views/admin/AvailableChannelsView.vue
index 74e85618..a9b2462f 100644
--- a/frontend/src/views/admin/AvailableChannelsView.vue
+++ b/frontend/src/views/admin/AvailableChannelsView.vue
@@ -46,16 +46,16 @@
- {{ statusStyles[row.status as ChannelStatus].label }}
+ {{ statusStyleOf(row.status).label }}
- {{ billingSourceLabels[row.billing_model_source as BillingModelSource] }}
+ {{ billingSourceLabelOf(row.billing_model_source) }}
@@ -101,7 +101,7 @@ const columns = computed(() => [
/**
* 显示样式:i18n label + Tailwind class,按 ChannelStatus 完整穷举。
- * 用 Record 强制未来新增状态时 TS 编译失败,避免遗漏分支。
+ * Record 键类型强制未来新增 ChannelStatus 成员时 TS 编译失败,避免遗漏分支。
*/
const statusStyles = computed>(() => ({
[CHANNEL_STATUS_ACTIVE]: {
@@ -124,6 +124,19 @@ const billingSourceLabels = computed>(() => (
[BILLING_MODEL_SOURCE_CHANNEL_MAPPED]: t('admin.availableChannels.billingSource.channel_mapped')
}))
+// 运行时兜底:即便 service 层归一化漏点或后端新增未同步的 enum 值传入,
+// 也不会触发 undefined.cls 崩溃;统一降级为 "-"。
+const DEFAULT_STATUS_STYLE = { label: '-', cls: '' }
+const DEFAULT_BILLING_LABEL = '-'
+
+function statusStyleOf(status: ChannelStatus | undefined): { label: string; cls: string } {
+ return status ? statusStyles.value[status] : DEFAULT_STATUS_STYLE
+}
+
+function billingSourceLabelOf(src: BillingModelSource | undefined): string {
+ return src ? billingSourceLabels.value[src] : DEFAULT_BILLING_LABEL
+}
+
const filteredChannels = computed(() => {
const q = searchQuery.value.trim().toLowerCase()
if (!q) return channels.value
--
GitLab
From 422f3449a23f937f147f6db7740a550f7ae6c5d0 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:54:42 +0800
Subject: [PATCH 147/261] chore: remove local docs from repo
---
.gitignore | 5 +-
docs/ADMIN_PAYMENT_INTEGRATION_API.md | 243 ------
docs/PAYMENT.md | 287 -------
docs/PAYMENT_CN.md | 287 -------
...-04-20-auth-identity-payment-foundation.md | 539 -------------
...auth-identity-payment-foundation-design.md | 763 ------------------
6 files changed, 1 insertion(+), 2123 deletions(-)
delete mode 100644 docs/ADMIN_PAYMENT_INTEGRATION_API.md
delete mode 100644 docs/PAYMENT.md
delete mode 100644 docs/PAYMENT_CN.md
delete mode 100644 docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
delete mode 100644 docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
diff --git a/.gitignore b/.gitignore
index 1a92ea3e..cf2bda9f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,12 +126,9 @@ backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
vite.config.js
-docs/*
-!docs/PAYMENT.md
-!docs/PAYMENT_CN.md
+docs/
.serena/
.codex/
frontend/coverage/
aicodex
output/
-
diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md
deleted file mode 100644
index f674f86c..00000000
--- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md
+++ /dev/null
@@ -1,243 +0,0 @@
-# ADMIN_PAYMENT_INTEGRATION_API
-
-> 单文件中英双语文档 / Single-file bilingual documentation (Chinese + English)
-
----
-
-## 中文
-
-### 目标
-本文档用于对接外部支付系统(如 `sub2apipay`)与 Sub2API 的 Admin API,覆盖:
-- 支付成功后充值
-- 用户查询
-- 人工余额修正
-- 前端购买页参数透传
-
-### 基础地址
-- 生产:`https://`
-- Beta:`http://:8084`
-
-### 认证
-推荐使用:
-- `x-api-key: admin-<64hex>`
-- `Content-Type: application/json`
-- 幂等接口额外传:`Idempotency-Key`
-
-说明:管理员 JWT 也可访问 admin 路由,但服务间调用建议使用 Admin API Key。
-
-### 1) 一步完成创建并兑换
-`POST /api/v1/admin/redeem-codes/create-and-redeem`
-
-用途:原子完成“创建兑换码 + 兑换到指定用户”。
-
-请求头:
-- `x-api-key`
-- `Idempotency-Key`
-
-请求体示例:
-```json
-{
- "code": "s2p_cm1234567890",
- "type": "balance",
- "value": 100.0,
- "user_id": 123,
- "notes": "sub2apipay order: cm1234567890"
-}
-```
-
-幂等语义:
-- 同 `code` 且 `used_by` 一致:`200`
-- 同 `code` 但 `used_by` 不一致:`409`
-- 缺少 `Idempotency-Key`:`400`(`IDEMPOTENCY_KEY_REQUIRED`)
-
-curl 示例:
-```bash
-curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: pay-cm1234567890-success" \
- -H "Content-Type: application/json" \
- -d '{
- "code":"s2p_cm1234567890",
- "type":"balance",
- "value":100.00,
- "user_id":123,
- "notes":"sub2apipay order: cm1234567890"
- }'
-```
-
-### 2) 查询用户(可选前置校验)
-`GET /api/v1/admin/users/:id`
-
-```bash
-curl -s "${BASE}/api/v1/admin/users/123" \
- -H "x-api-key: ${KEY}"
-```
-
-### 3) 余额调整(已有接口)
-`POST /api/v1/admin/users/:id/balance`
-
-用途:人工补偿 / 扣减,支持 `set` / `add` / `subtract`。
-
-请求体示例(扣减):
-```json
-{
- "balance": 100.0,
- "operation": "subtract",
- "notes": "manual correction"
-}
-```
-
-```bash
-curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: balance-subtract-cm1234567890" \
- -H "Content-Type: application/json" \
- -d '{
- "balance":100.00,
- "operation":"subtract",
- "notes":"manual correction"
- }'
-```
-
-### 4) 购买页 / 自定义页面 URL Query 透传(iframe / 新窗口一致)
-当 Sub2API 打开 `purchase_subscription_url` 或用户侧自定义页面 iframe URL 时,会统一追加:
-- `user_id`
-- `token`
-- `theme`(`light` / `dark`)
-- `lang`(例如 `zh` / `en`,用于向嵌入页传递当前界面语言)
-- `ui_mode`(固定 `embedded`)
-
-示例:
-```text
-https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded
-```
-
-### 5) 失败处理建议
-- 支付成功与充值成功分状态落库
-- 回调验签成功后立即标记“支付成功”
-- 支付成功但充值失败的订单允许后续重试
-- 重试保持相同 `code`,并使用新的 `Idempotency-Key`
-
-### 6) `doc_url` 配置建议
-- 查看链接:`https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-- 下载链接:`https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-
----
-
-## English
-
-### Purpose
-This document describes the minimal Sub2API Admin API surface for external payment integrations (for example, `sub2apipay`), including:
-- Recharge after payment success
-- User lookup
-- Manual balance correction
-- Purchase page query parameter forwarding
-
-### Base URL
-- Production: `https://`
-- Beta: `http://:8084`
-
-### Authentication
-Recommended headers:
-- `x-api-key: admin-<64hex>`
-- `Content-Type: application/json`
-- `Idempotency-Key` for idempotent endpoints
-
-Note: Admin JWT can also access admin routes, but Admin API Key is recommended for server-to-server integration.
-
-### 1) Create and Redeem in one step
-`POST /api/v1/admin/redeem-codes/create-and-redeem`
-
-Use case: atomically create a redeem code and redeem it to a target user.
-
-Headers:
-- `x-api-key`
-- `Idempotency-Key`
-
-Request body:
-```json
-{
- "code": "s2p_cm1234567890",
- "type": "balance",
- "value": 100.0,
- "user_id": 123,
- "notes": "sub2apipay order: cm1234567890"
-}
-```
-
-Idempotency behavior:
-- Same `code` and same `used_by`: `200`
-- Same `code` but different `used_by`: `409`
-- Missing `Idempotency-Key`: `400` (`IDEMPOTENCY_KEY_REQUIRED`)
-
-curl example:
-```bash
-curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: pay-cm1234567890-success" \
- -H "Content-Type: application/json" \
- -d '{
- "code":"s2p_cm1234567890",
- "type":"balance",
- "value":100.00,
- "user_id":123,
- "notes":"sub2apipay order: cm1234567890"
- }'
-```
-
-### 2) Query User (optional pre-check)
-`GET /api/v1/admin/users/:id`
-
-```bash
-curl -s "${BASE}/api/v1/admin/users/123" \
- -H "x-api-key: ${KEY}"
-```
-
-### 3) Balance Adjustment (existing API)
-`POST /api/v1/admin/users/:id/balance`
-
-Use case: manual correction with `set` / `add` / `subtract`.
-
-Request body example (`subtract`):
-```json
-{
- "balance": 100.0,
- "operation": "subtract",
- "notes": "manual correction"
-}
-```
-
-```bash
-curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: balance-subtract-cm1234567890" \
- -H "Content-Type: application/json" \
- -d '{
- "balance":100.00,
- "operation":"subtract",
- "notes":"manual correction"
- }'
-```
-
-### 4) Purchase / Custom Page URL query forwarding (iframe and new tab)
-When Sub2API opens `purchase_subscription_url` or a user-facing custom page iframe URL, it appends:
-- `user_id`
-- `token`
-- `theme` (`light` / `dark`)
-- `lang` (for example `zh` / `en`, used to pass the current UI language to the embedded page)
-- `ui_mode` (fixed: `embedded`)
-
-Example:
-```text
-https://pay.example.com/pay?user_id=123&token=&theme=light&lang=zh&ui_mode=embedded
-```
-
-### 5) Failure handling recommendations
-- Persist payment success and recharge success as separate states
-- Mark payment as successful immediately after verified callback
-- Allow retry for orders with payment success but recharge failure
-- Keep the same `code` for retry, and use a new `Idempotency-Key`
-
-### 6) Recommended `doc_url`
-- View URL: `https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-- Download URL: `https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
deleted file mode 100644
index 9322f7bf..00000000
--- a/docs/PAYMENT.md
+++ /dev/null
@@ -1,287 +0,0 @@
-# Payment System Configuration Guide
-
-Sub2API has a built-in payment system that enables user self-service top-up without deploying a separate payment service.
-
----
-
-## Table of Contents
-
-- [Supported Payment Methods](#supported-payment-methods)
-- [Quick Start](#quick-start)
-- [System Settings](#system-settings)
-- [Provider Configuration](#provider-configuration)
-- [Provider Instance Management](#provider-instance-management)
-- [Webhook Configuration](#webhook-configuration)
-- [Payment Flow](#payment-flow)
-- [Migrating from Sub2ApiPay](#migrating-from-sub2apipay)
-
----
-
-## Supported Payment Methods
-
-| Provider | Payment Methods | Description |
-|----------|----------------|-------------|
-| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol |
-| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links |
-| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing |
-| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support |
-
-> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
-
-> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
->
-> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
-> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer.
->
-> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
-
----
-
-## Quick Start
-
-1. Go to Admin Dashboard → **Settings** → **Payment Settings** tab
-2. Enable **Payment**
-3. Configure basic parameters (amount range, timeout, etc.)
-4. Add at least one provider instance in **Provider Management**
-5. Users can now top up from the frontend
-
----
-
-## System Settings
-
-Configure the following in Admin Dashboard **Settings → Payment Settings**:
-
-### Basic Settings
-
-| Setting | Description | Default |
-|---------|-------------|---------|
-| **Enable Payment** | Enable or disable the payment system | Off |
-| **Product Name Prefix** | Prefix shown on payment page | - |
-| **Product Name Suffix** | Suffix (e.g., "Credits") | - |
-| **Minimum Amount** | Minimum single top-up amount | 1 |
-| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - |
-| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - |
-| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 |
-| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
-| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin |
-
-### Frontend Visible Method Routing
-
-The current payment UX keeps the frontend method list unified and does not expose provider brands directly:
-
-- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay`
-- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat`
-- Each visible method can route to only one source at a time
-- If a visible method is enabled without a selected source, the frontend will not expose that method
-
-### Load Balance Strategies
-
-| Strategy | Description |
-|----------|-------------|
-| **Round Robin** | Distribute orders to instances in rotation |
-| **Least Amount** | Prefer instances with the lowest daily cumulative amount |
-
-### Cancel Rate Limiting
-
-Prevents users from repeatedly creating and canceling orders:
-
-| Setting | Description |
-|---------|-------------|
-| **Enable Limit** | Toggle |
-| **Window Mode** | Sliding / Fixed window |
-| **Time Window** | Window duration |
-| **Window Unit** | Minutes / Hours |
-| **Max Cancels** | Maximum cancellations allowed within the window |
-
-### Help Information
-
-| Setting | Description |
-|---------|-------------|
-| **Help Image** | Customer service QR code or help image (supports upload) |
-| **Help Text** | Instructions displayed on the payment page |
-
----
-
-## Provider Configuration
-
-Each provider type requires different credentials. Select the type when adding a new provider instance in **Provider Management → Add Provider**.
-
-> **Callback URLs are auto-generated**: When adding a provider, the Notify URL and Return URL are automatically constructed from your site domain. You only need to confirm the domain is correct.
-
-### EasyPay
-
-Compatible with any payment service that implements the EasyPay protocol.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **Merchant ID (PID)** | EasyPay merchant ID | Yes |
-| **Merchant Key (PKey)** | EasyPay merchant secret key | Yes |
-| **API Base URL** | EasyPay API base address | Yes |
-| **Alipay Channel ID** | Specify Alipay channel (optional) | No |
-| **WeChat Channel ID** | Specify WeChat channel (optional) | No |
-
-### Alipay (Direct)
-
-Direct integration with Alipay Open Platform. Desktop flows return a QR code for in-page display, while mobile flows return an Alipay WAP/app redirect URL.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **AppID** | Alipay application AppID | Yes |
-| **Private Key** | RSA2 application private key | Yes |
-| **Alipay Public Key** | Alipay public key | Yes |
-
-### WeChat Pay (Direct)
-
-Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **AppID** | WeChat Pay AppID | Yes |
-| **Merchant ID (MchID)** | WeChat Pay merchant ID | Yes |
-| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
-| **APIv3 Key** | 32-byte APIv3 key | Yes |
-| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
-| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
-| **Certificate Serial Number** | Merchant certificate serial number | Yes |
-
-### Stripe
-
-International payment platform supporting multiple payment methods and currencies.
-
-| Parameter | Description | Required |
-|-----------|-------------|----------|
-| **Secret Key** | Stripe secret key (`sk_live_...` or `sk_test_...`) | Yes |
-| **Publishable Key** | Stripe publishable key (`pk_live_...` or `pk_test_...`) | Yes |
-| **Webhook Secret** | Stripe Webhook signing secret (`whsec_...`) | Yes |
-
----
-
-## Provider Instance Management
-
-You can create **multiple instances** of the same provider type for load balancing and risk control:
-
-- **Multi-instance load balancing** — Distribute orders via round-robin or least-amount strategy
-- **Independent limits** — Each instance can have its own min/max amount and daily limit
-- **Independent toggle** — Enable/disable individual instances without affecting others
-- **Refund control** — Enable or disable refunds per instance
-- **Payment methods** — Each instance can support a subset of payment methods
-- **Ordering** — Drag to reorder instances
-
-### Instance Limit Configuration
-
-Each instance supports these limits:
-
-| Limit | Description |
-|-------|-------------|
-| **Minimum Amount** | Minimum order amount accepted by this instance |
-| **Maximum Amount** | Maximum order amount accepted by this instance |
-| **Daily Limit** | Daily cumulative transaction limit for this instance |
-
-> During load balancing, instances that exceed their limits are automatically skipped.
-
----
-
-## Webhook Configuration
-
-Payment callbacks are essential for the payment system to work correctly.
-
-### Callback URL Format
-
-When adding a provider, the system auto-generates callback URLs from your site domain:
-
-| Provider | Callback Path |
-|----------|-------------|
-| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` |
-| **Alipay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/alipay` |
-| **WeChat Pay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/wxpay` |
-| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` |
-
-> Replace `your-domain.com` with your actual domain. For EasyPay / Alipay / WeChat Pay, the callback URL is auto-filled when adding the provider — no manual configuration needed.
-
-### Stripe Webhook Setup
-
-1. Log in to [Stripe Dashboard](https://dashboard.stripe.com/)
-2. Go to **Developers → Webhooks**
-3. Add an endpoint with the callback URL
-4. Subscribe to events: `payment_intent.succeeded`, `payment_intent.payment_failed`
-5. Copy the generated Webhook Secret (`whsec_...`) to your provider configuration
-
-### Important Notes
-
-- Callback URLs must use **HTTPS** (required by Stripe, strongly recommended for others)
-- Ensure your firewall allows callback requests from payment platforms
-- The system automatically verifies callback signatures to prevent forgery
-- Balance top-up is processed automatically upon successful payment — no manual intervention needed
-
----
-
-## Payment Flow
-
-```
-User selects amount and payment method
- │
- ▼
- Create Order (PENDING)
- ├─ Validate amount range, pending order count, daily limit
- ├─ Load balance to select provider instance
- └─ Call provider to get payment info
- │
- ▼
- User completes payment
- ├─ EasyPay → QR code / H5 redirect
- ├─ Alipay → Desktop QR / mobile Alipay redirect
- ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI
- └─ Stripe → Payment Element (card/Alipay/WeChat/etc.)
- │
- ▼
- Webhook callback verified → Order PAID
- │
- ▼
- Auto top-up to user balance → Order COMPLETED
-```
-
-### Order Status Reference
-
-| Status | Description |
-|--------|-------------|
-| `PENDING` | Waiting for user to complete payment |
-| `PAID` | Payment confirmed, awaiting balance credit |
-| `COMPLETED` | Balance credited successfully |
-| `EXPIRED` | Timed out without payment |
-| `CANCELLED` | Cancelled by user |
-| `FAILED` | Balance credit failed, admin can retry |
-| `REFUND_REQUESTED` | Refund requested |
-| `REFUNDING` | Refund in progress |
-| `REFUNDED` | Refund completed |
-
-### Timeout and Fallback
-
-- Before marking an order as expired, the background job queries the upstream payment status first
-- If the user has actually paid but the callback was delayed, the system will reconcile automatically
-- The background job runs every 60 seconds to check for timed-out orders
-
----
-
-## Migrating from Sub2ApiPay
-
-If you previously used [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) as an external payment system, you can migrate to the built-in payment system:
-
-### Key Differences
-
-| Aspect | Sub2ApiPay | Built-in Payment |
-|--------|-----------|-----------------|
-| Deployment | Separate service (Next.js + PostgreSQL) | Built into Sub2API, no extra deployment |
-| Payment Methods | EasyPay, Alipay, WeChat, Stripe | Same |
-| Configuration | Environment variables + separate admin UI | Unified in Sub2API admin dashboard |
-| Top-up Integration | Via Admin API callback | Internal processing, more reliable |
-| Subscription Plans | Supported | Not yet (planned) |
-| Order Management | Separate admin interface | Integrated in Sub2API admin dashboard |
-
-### Migration Steps
-
-1. Enable payment in Sub2API admin dashboard and configure providers (use the same payment credentials)
-2. Update webhook callback URLs to Sub2API's callback endpoints
-3. Verify that new orders are processed correctly via built-in payment
-4. Decommission the Sub2ApiPay service
-
-> **Note**: Historical order data from Sub2ApiPay will not be automatically migrated. Keep Sub2ApiPay running for a while to access historical records.
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
deleted file mode 100644
index 0fbc198a..00000000
--- a/docs/PAYMENT_CN.md
+++ /dev/null
@@ -1,287 +0,0 @@
-# 支付系统配置指南
-
-Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支付服务。
-
----
-
-## 目录
-
-- [支持的支付方式](#支持的支付方式)
-- [快速开始](#快速开始)
-- [系统设置](#系统设置)
-- [服务商配置](#服务商配置)
-- [服务商实例管理](#服务商实例管理)
-- [Webhook 配置](#webhook-配置)
-- [支付流程](#支付流程)
-- [从 Sub2ApiPay 迁移](#从-sub2apipay-迁移)
-
----
-
-## 支持的支付方式
-
-| 服务商 | 支付方式 | 说明 |
-|--------|---------|------|
-| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 |
-| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 |
-| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 |
-| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-
-> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
-
-> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
->
-> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
-> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。
->
-> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
-
----
-
-## 快速开始
-
-1. 进入管理后台 → **设置** → **支付设置** 标签页
-2. 开启 **启用支付**
-3. 配置基本参数(金额范围、超时时间等)
-4. 在 **服务商管理** 中添加至少一个服务商实例
-5. 用户即可在前端页面进行充值
-
----
-
-## 系统设置
-
-在管理后台 **设置 → 支付设置** 中配置以下参数:
-
-### 基本设置
-
-| 设置项 | 说明 | 默认值 |
-|--------|------|--------|
-| **启用支付** | 启用或禁用支付系统 | 关闭 |
-| **商品名前缀** | 支付页面显示的商品名前缀 | - |
-| **商品名后缀** | 商品名后缀(如"元") | - |
-| **最低金额** | 单笔最低充值金额 | 1 |
-| **最高金额** | 单笔最高充值金额(留空表示不限制) | - |
-| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - |
-| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 |
-| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
-| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 |
-
-### 前台可见支付方式路由
-
-当前版本对用户统一展示支付方式,不区分官方渠道还是易支付:
-
-- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝`
-- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信`
-- 同一个可见支付方式在同一时刻只能路由到一个来源
-- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式
-
-### 负载均衡策略
-
-| 策略 | 说明 |
-|------|------|
-| **轮询(round-robin)** | 按顺序轮流分配到各服务商实例 |
-| **最少金额(least-amount)** | 优先分配到当日累计金额最少的实例 |
-
-### 取消频率限制
-
-防止用户频繁创建并取消订单:
-
-| 设置项 | 说明 |
-|--------|------|
-| **启用限制** | 开关 |
-| **窗口模式** | 滚动窗口 / 固定窗口 |
-| **时间窗口** | 窗口长度 |
-| **窗口单位** | 分钟 / 小时 |
-| **最大次数** | 窗口内允许的最大取消次数 |
-
-### 帮助信息
-
-| 设置项 | 说明 |
-|--------|------|
-| **帮助图片** | 充值页面显示的客服二维码等图片(支持上传) |
-| **帮助文本** | 充值页面显示的说明文字 |
-
----
-
-## 服务商配置
-
-每种服务商需要不同的凭证和参数。在 **服务商管理 → 添加服务商** 中选择类型后填写。
-
-> **回调地址自动生成**:添加服务商时,异步回调地址(Notify URL)和同步跳转地址(Return URL)由系统根据你的站点域名自动拼接,无需手动填写。管理员只需确认域名正确即可。
-
-### EasyPay(易支付)
-
-兼容任何 EasyPay 协议的支付服务商。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **商户 ID(PID)** | EasyPay 商户 ID | 是 |
-| **商户密钥(PKey)** | EasyPay 商户密钥 | 是 |
-| **API 地址** | EasyPay API 基础地址 | 是 |
-| **支付宝通道 ID** | 指定支付宝通道(可选) | 否 |
-| **微信通道 ID** | 指定微信通道(可选) | 否 |
-
-### 支付宝官方
-
-直接对接支付宝开放平台。桌面端返回二维码供页面内展示和扫码,移动端返回支付宝手机网站支付跳转链接。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **AppID** | 支付宝应用 AppID | 是 |
-| **应用私钥** | RSA2 应用私钥 | 是 |
-| **支付宝公钥** | 支付宝公钥 | 是 |
-
-### 微信官方
-
-直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **AppID** | 微信支付 AppID | 是 |
-| **商户号(MchID)** | 微信支付商户号 | 是 |
-| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
-| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
-| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
-| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
-| **商户证书序列号** | 商户证书序列号 | 是 |
-
-### Stripe
-
-国际支付平台,支持多种支付方式和币种。
-
-| 参数 | 说明 | 必填 |
-|------|------|------|
-| **Secret Key** | Stripe 密钥(`sk_live_...` 或 `sk_test_...`) | 是 |
-| **Publishable Key** | Stripe 可公开密钥(`pk_live_...` 或 `pk_test_...`) | 是 |
-| **Webhook Secret** | Stripe Webhook 签名密钥(`whsec_...`) | 是 |
-
----
-
-## 服务商实例管理
-
-同一种服务商可以创建**多个实例**,实现负载均衡和风控:
-
-- **多实例负载均衡** — 按轮询或最少金额策略分流订单
-- **独立限额** — 每个实例可独立配置单笔最小/最大金额和每日限额
-- **独立启停** — 可单独启用/禁用某个实例,不影响其他实例
-- **退款控制** — 每个实例可单独开启或关闭退款功能
-- **支付方式** — 每个实例可选择支持的支付方式子集
-- **排序** — 拖拽调整实例顺序
-
-### 实例限额配置
-
-每个实例支持以下限额:
-
-| 限额项 | 说明 |
-|--------|------|
-| **单笔最小金额** | 该实例接受的最小订单金额 |
-| **单笔最大金额** | 该实例接受的最大订单金额 |
-| **每日限额** | 该实例每日累计交易上限 |
-
-> 负载均衡时,系统会自动跳过超出限额的实例。
-
----
-
-## Webhook 配置
-
-支付回调是支付系统的核心环节,必须正确配置:
-
-### 回调地址格式
-
-添加服务商时,系统会自动根据站点域名拼接回调地址,格式如下:
-
-| 服务商 | 回调路径 |
-|--------|---------|
-| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` |
-| **支付宝官方** | `https://your-domain.com/api/v1/payment/webhook/alipay` |
-| **微信官方** | `https://your-domain.com/api/v1/payment/webhook/wxpay` |
-| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` |
-
-> 将 `your-domain.com` 替换为你的实际域名。EasyPay / 支付宝 / 微信的回调地址在添加服务商时自动填入,无需手动配置。
-
-### Stripe Webhook 设置
-
-1. 登录 [Stripe Dashboard](https://dashboard.stripe.com/)
-2. 进入 **Developers → Webhooks**
-3. 添加端点,填写回调地址
-4. 订阅事件:`payment_intent.succeeded`、`payment_intent.payment_failed`
-5. 将生成的 Webhook Secret(`whsec_...`)填入服务商配置
-
-### 注意事项
-
-- 回调地址必须是 **HTTPS**(Stripe 强制要求,其他服务商强烈推荐)
-- 确保服务器防火墙允许支付平台的回调请求
-- 系统会自动进行签名验证,防止伪造回调
-- 支付成功后自动完成余额充值,无需人工干预
-
----
-
-## 支付流程
-
-```
-用户选择充值金额和支付方式
- │
- ▼
- 创建订单 (PENDING)
- ├─ 校验金额范围、待支付订单数、每日限额
- ├─ 负载均衡选择服务商实例
- └─ 调用服务商获取支付信息
- │
- ▼
- 用户完成支付
- ├─ EasyPay → 扫码 / H5 跳转
- ├─ 支付宝官方 → 桌面二维码 / 移动端支付宝跳转
- ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI
- └─ Stripe → Payment Element(银行卡/支付宝/微信等)
- │
- ▼
- 支付回调验签 → 订单 PAID
- │
- ▼
- 自动充值到用户余额 → 订单 COMPLETED
-```
-
-### 订单状态说明
-
-| 状态 | 说明 |
-|------|------|
-| `PENDING` | 待支付,等待用户完成支付 |
-| `PAID` | 已支付,等待充值到账 |
-| `COMPLETED` | 已完成,余额已到账 |
-| `EXPIRED` | 已过期,超时未支付 |
-| `CANCELLED` | 已取消,用户主动取消 |
-| `FAILED` | 充值失败,可管理员重试 |
-| `REFUND_REQUESTED` | 已申请退款 |
-| `REFUNDING` | 退款处理中 |
-| `REFUNDED` | 已退款 |
-
-### 超时与兜底
-
-- 订单超时后,后台任务会先查询上游支付状态再标记过期
-- 如果用户实际已支付但回调延迟,系统会通过查询补单
-- 后台任务每 60 秒执行一次超时检查
-
----
-
-## 从 Sub2ApiPay 迁移
-
-如果你之前使用 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 作为外部支付系统,现在可以迁移到内置支付:
-
-### 主要差异
-
-| 对比项 | Sub2ApiPay | 内置支付 |
-|--------|-----------|---------|
-| 部署方式 | 独立服务(Next.js + PostgreSQL) | 内置于 Sub2API,无需额外部署 |
-| 支付方式 | EasyPay、支付宝、微信、Stripe | 相同 |
-| 配置方式 | 环境变量 + 独立管理后台 | Sub2API 管理后台内统一配置 |
-| 充值对接 | 通过 Admin API 回调 | 内部直接处理,更可靠 |
-| 订阅套餐 | 支持 | 暂不支持(计划中) |
-| 订单管理 | 独立管理界面 | 集成在 Sub2API 管理后台 |
-
-### 迁移步骤
-
-1. 在 Sub2API 管理后台启用支付并配置服务商(使用相同的支付凭证)
-2. 更新 Webhook 回调地址为 Sub2API 的回调地址
-3. 确认新订单通过内置支付正常处理
-4. 停用 Sub2ApiPay 服务
-
-> **注意**:Sub2ApiPay 中的历史订单数据不会自动迁移。建议保留 Sub2ApiPay 一段时间以便查询历史记录。
diff --git a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md b/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
deleted file mode 100644
index 2d44e058..00000000
--- a/docs/superpowers/plans/2026-04-20-auth-identity-payment-foundation.md
+++ /dev/null
@@ -1,539 +0,0 @@
-# Auth Identity Payment Foundation Implementation Plan
-
-> **For agentic workers:** REQUIRED SUB-SKILL: Use `superpowers:subagent-driven-development` (recommended) or `superpowers:executing-plans` to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
-
-**Goal:** Rebuild the auth identity, profile binding, payment routing, and OpenAI advanced scheduler foundation on top of a clean `origin/main` branch while preserving historical compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only records.
-**Architecture:** A unified identity foundation centered on durable provider subjects (`email`, `linuxdo`, `oidc`, `wechat`) and transactional pending-auth sessions; backend-owned payment source routing behind stable frontend methods (`alipay`, `wxpay`); compatibility-first migration/backfill before feature enablement.
-**Tech Stack:** Go, Gin, Ent, PostgreSQL, Redis, Vue 3, Pinia, TypeScript, Vitest, pnpm.
-
----
-
-## Non-Negotiable Product Rules
-
-- [ ] Preserve login continuity for existing email users, existing LinuxDo users, and historically migrated third-party users.
-- [ ] During migration, backfill historical LinuxDo/WeChat/OIDC synthetic-email users into explicit third-party identities before first post-upgrade login whenever deterministic recovery is possible.
-- [ ] During migration, surface historical WeChat `openid`-only records through explicit migration reports and remediation rules; do not silently reinterpret them as valid canonical identities.
-- [ ] Keep existing email login and add third-party login/bind for `linuxdo`, `oidc`, and `wechat`.
-- [ ] On first third-party login:
- - identity exists: direct login.
- - identity does not exist: start pending-auth flow.
- - local email binding is required only when system config says so.
- - upstream provider email verification never counts as local email verification.
-- [ ] When user-entered and locally verified email already exists:
- - offer bind-existing-account after local re-authentication.
- - offer change-email-and-create-new-account.
- - when email binding is mandatory, do not allow bypass without changing to another email.
-- [ ] On first third-party login or first third-party bind, provider nickname/avatar must be presented as independent replace options for the current nickname and avatar. They are not auto-applied.
-- [ ] Source-specific initial grants must support per-source defaults for balance, concurrency, and subscriptions.
-- [ ] Default grant timing: on successful new-account creation.
-- [ ] Optional grant timing: on first successful bind for the configured source.
-- [ ] Migration/backfill must never trigger first-bind or first-signup grants retroactively.
-- [ ] Avatar profile supports:
- - direct URL storage.
- - image data URL upload compressed to `<=100KB` before storing in DB.
- - explicit delete.
-- [ ] Admin user management must expose and sort by `last_login_at` and `last_active_at`.
-- [ ] WeChat login rules:
- - WeChat environment uses MP login.
- - non-WeChat browser uses Open/QR login.
- - canonical identity uses `unionid`.
- - when `unionid` is unavailable, fail the login/bind flow under the approved option-1 policy.
-- [ ] OIDC rules:
- - browser authorization-code flow always uses PKCE `S256`.
- - discovery issuer and ID token `iss` must match exactly.
- - `userinfo.sub` must match ID token `sub` when UserInfo is used.
- - upstream `email_verified` does not satisfy local email verification.
-- [ ] Payment UI rules:
- - user-facing methods stay `支付宝` and `微信支付`.
- - backend decides whether each method routes to official provider instance or EasyPay.
- - at runtime, each visible method may only have one active source.
-- [ ] Alipay rules:
- - PC: in-page QR.
- - mobile browser: jump to Alipay payment.
-- [ ] WeChat Pay rules:
- - PC: in-page QR.
- - WeChat H5: MP/JSAPI first, fallback to H5 pay.
- - non-WeChat H5: H5 pay, or prompt to open in WeChat when unavailable.
-- [ ] Payment success pages are informational only; actual fulfillment depends on webhook or server-side reconciliation.
-- [ ] WeChat in-app payment requiring `openid` must use a dedicated server-backed payment OAuth resume flow rather than frontend-only recovery state.
-- [ ] OpenAI advanced scheduler is available but default-disabled.
-
-## Hard Technical Constraints From Audit
-
-- [ ] Browser-based third-party auth must use Authorization Code + PKCE `S256`.
-- [ ] PKCE must not be admin-configurable off for browser authorization-code providers.
-- [ ] OIDC identity primary key must be `(issuer, subject)`, not email.
-- [ ] Email equality must never auto-link accounts.
-- [ ] Bind-existing-account must require explicit local re-authentication and TOTP verification when enabled.
-- [ ] Bind-current-user must originate from an already-authenticated local user and preserve explicit bind intent across callback completion.
-- [ ] OAuth redirect URI must be fixed server config, exact-match, and never derived from user input.
-- [ ] User-supplied redirect may only choose a normalized same-origin internal route after completion.
-- [ ] WeChat canonical identity must be `unionid`; `openid` remains channel/app-scoped support data only.
-- [ ] Every canonical identity uniqueness rule must include provider namespace (`provider_key`) consistently.
-- [ ] Callback completion must use backend session completion or a one-time opaque exchange code that is short-lived, one-time, browser-session-bound, `POST`-redeemed, and unusable as a bearer token.
-- [ ] Every payment order must snapshot the selected provider instance plus the order-time verification inputs required for callback verification, reconciliation, refund, and audit.
-- [ ] Frontend must not receive first-party bearer tokens through callback URL fragments in the rebuilt flow.
-- [ ] Public payment result polling must not expose order data by raw `out_trade_no` alone; use authenticated lookup or signed opaque result token.
-- [ ] WeChat Pay webhook handling must verify signature, decrypt payload, and compare `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and provider trade state against the order snapshot before fulfillment.
-
-## Baseline Notes
-
-- [ ] Current clean branch head when this plan was written: `721d7ab3`.
-- [ ] Baseline backend verification on clean `origin/main`: `cd backend && go test ./...` passes.
-- [ ] Baseline frontend verification on clean `origin/main`: `cd frontend && pnpm test:run` currently fails in unrelated existing suites. New work must add targeted tests and avoid claiming full frontend green until those baseline failures are addressed separately.
-- [ ] Existing migration directory currently ends at `107_*`; this rebuild reserves `108` through `111`.
-
-## Target File Map
-
-### New backend migrations
-
-- [ ] `backend/migrations/108_auth_identity_foundation_core.sql`
-- [ ] `backend/migrations/109_auth_identity_compat_backfill.sql`
-- [ ] `backend/migrations/110_pending_auth_and_provider_default_grants.sql`
-- [ ] `backend/migrations/111_payment_routing_and_scheduler_flags.sql`
-
-### New or rebuilt Ent schema
-
-- [ ] `backend/ent/schema/auth_identity.go`
-- [ ] `backend/ent/schema/auth_identity_channel.go`
-- [ ] `backend/ent/schema/pending_auth_session.go`
-- [ ] `backend/ent/schema/identity_adoption_decision.go`
-
-### New or rebuilt backend repositories/services/handlers
-
-- [ ] `backend/internal/repository/user_profile_identity_repo.go`
-- [ ] `backend/internal/repository/user_profile_identity_repo_contract_test.go`
-- [ ] `backend/internal/repository/auth_identity_migration_report.go`
-- [ ] `backend/internal/service/auth_identity_flow.go`
-- [ ] `backend/internal/service/auth_identity_flow_test.go`
-- [ ] `backend/internal/service/auth_pending_identity_service.go`
-- [ ] `backend/internal/service/auth_pending_identity_service_test.go`
-- [ ] `backend/internal/service/payment_config_service.go`
-- [ ] `backend/internal/service/payment_order.go`
-- [ ] `backend/internal/service/payment_order_lifecycle.go`
-- [ ] `backend/internal/service/payment_fulfillment.go`
-- [ ] `backend/internal/service/payment_resume_service.go`
-- [ ] `backend/internal/service/payment_resume_service_test.go`
-- [ ] `backend/internal/service/openai_account_scheduler.go`
-- [ ] `backend/internal/handler/auth_pending_identity_flow.go`
-- [ ] `backend/internal/handler/auth_linuxdo_oauth.go`
-- [ ] `backend/internal/handler/auth_oidc_oauth.go`
-- [ ] `backend/internal/handler/auth_wechat_oauth.go`
-- [ ] `backend/internal/handler/auth_handler.go`
-- [ ] `backend/internal/handler/user_handler.go`
-- [ ] `backend/internal/handler/payment_handler.go`
-- [ ] `backend/internal/handler/payment_webhook_handler.go`
-- [ ] `backend/internal/handler/admin/user_handler.go`
-- [ ] `backend/internal/handler/admin/setting_handler.go`
-
-### New or rebuilt frontend API/store/views/components
-
-- [ ] `frontend/src/api/auth.ts`
-- [ ] `frontend/src/api/user.ts`
-- [ ] `frontend/src/api/payment.ts`
-- [ ] `frontend/src/api/admin/settings.ts`
-- [ ] `frontend/src/api/admin/users.ts`
-- [ ] `frontend/src/stores/auth.ts`
-- [ ] `frontend/src/stores/payment.ts`
-- [ ] `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue`
-- [ ] `frontend/src/components/auth/LinuxDoOAuthSection.vue`
-- [ ] `frontend/src/components/auth/OidcOAuthSection.vue`
-- [ ] `frontend/src/components/auth/WechatOAuthSection.vue`
-- [ ] `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue`
-- [ ] `frontend/src/components/user/profile/ProfileInfoCard.vue`
-- [ ] `frontend/src/views/auth/LinuxDoCallbackView.vue`
-- [ ] `frontend/src/views/auth/OidcCallbackView.vue`
-- [ ] `frontend/src/views/auth/WechatCallbackView.vue`
-- [ ] `frontend/src/views/user/ProfileView.vue`
-- [ ] `frontend/src/views/user/PaymentView.vue`
-- [ ] `frontend/src/views/user/PaymentQRCodeView.vue`
-- [ ] `frontend/src/views/user/PaymentResultView.vue`
-
-## Phase 1: Migration And Compatibility Foundation
-
-### Task 1. Create core identity schema migration
-
-- [ ] Implement `backend/migrations/108_auth_identity_foundation_core.sql` with:
- - `auth_identities`
- - `auth_identity_channels`
- - `pending_auth_sessions`
- - `identity_adoption_decisions`
- - `users.last_login_at`
- - `users.last_active_at`
- - grant-tracking columns/tables required to prevent double-award
-- [ ] Add uniqueness/index rules:
- - one canonical identity per `(provider, provider_key, provider_subject)`
- - one channel record per `(provider, provider_channel, provider_app_id, provider_channel_subject)`
- - one adoption decision per pending session
-- [ ] Model `pending_auth_sessions` so immutable upstream claims and mutable local flow state are stored separately; do not reintroduce a mixed `metadata` catch-all.
-- [ ] Preserve null-safe compatibility defaults so historical rows remain readable before backfill finishes.
-- [ ] Add explicit rollback blocks only where safe; never repeat the destructive pattern observed in old `112_update_pending_auth_sessions.sql`.
-
-### Task 2. Materialize historical identities before runtime
-
-- [ ] Implement `backend/migrations/109_auth_identity_compat_backfill.sql` to backfill:
- - existing email users into `auth_identities(provider=email, provider_subject=normalized_email)`
- - historical LinuxDo users into `auth_identities(provider=linuxdo, provider_subject=linuxdo_subject)`
- - historical synthetic-email LinuxDo users into explicit LinuxDo identity rows by parsing legacy email mode and legacy provider metadata
- - historical synthetic-email WeChat users into explicit WeChat identities where `unionid` or equivalent deterministic provider identity is recoverable
- - historical synthetic-email OIDC users into explicit OIDC identities where deterministic provider identity is recoverable
- - profile/channel rows from historical `user_external_identities`-style data when present in upgraded databases
-- [ ] Write migration report output in `backend/internal/repository/auth_identity_migration_report.go` so production can inspect unmatched rows, `openid`-only WeChat rows, and non-deterministic synthetic-email rows instead of silently skipping them.
-- [ ] Set `signup_source` and provider provenance when recoverable from historical data. Do not flatten everything to `email`.
-
-### Task 3. Provider default grant and scheduler config migration
-
-- [ ] Implement `backend/migrations/110_pending_auth_and_provider_default_grants.sql` for:
- - provider-specific initial balance/concurrency/subscription defaults
- - grant timing flags: `on_signup`, optional `on_first_bind`
- - email-required-on-third-party-signup flags
- - profile avatar storage columns/settings
-- [ ] Implement `backend/migrations/111_payment_routing_and_scheduler_flags.sql` for:
- - stable payment method to provider-instance routing
- - visible-method normalization from historical `supported_types`, `payment_mode`, and legacy aliases such as `wxpay_direct`
- - admin exclusivity flags for `alipay` and `wxpay`
- - advanced scheduler enable flag defaulting to disabled
-
-### Task 4. Generate Ent and compile migration-safe model layer
-
-- [ ] Add the schema definitions in:
- - `backend/ent/schema/auth_identity.go`
- - `backend/ent/schema/auth_identity_channel.go`
- - `backend/ent/schema/pending_auth_session.go`
- - `backend/ent/schema/identity_adoption_decision.go`
-- [ ] Run:
- ```bash
- cd backend
- go generate ./ent
- ```
-- [ ] Compile after generation:
- ```bash
- cd backend
- go test ./... -run '^$'
- ```
-- [ ] Commit checkpoint:
- ```bash
- git add backend/migrations backend/ent/schema backend/ent
- git commit -m "feat: add auth identity foundation schema"
- ```
-
-## Phase 2: Backend Identity Flow Rebuild
-
-### Task 5. Build a single repository contract for identity lookups and grants
-
-- [ ] Implement `backend/internal/repository/user_profile_identity_repo.go` with transactional helpers for:
- - get user by canonical identity
- - get user by channel identity
- - create canonical + channel identity together
- - bind identity to existing user after verified re-auth
- - record one-time provider grant award
- - record adoption preference decisions
- - update `last_login_at` and `last_active_at`
-- [ ] Add repository contract coverage in `backend/internal/repository/user_profile_identity_repo_contract_test.go`.
-- [ ] Enforce dual-write for email registration/login so `users.email` and `auth_identities(provider=email, ...)` stay consistent from this phase onward.
-- [ ] Add repository coverage proving `last_login_at` and `last_active_at` use the required field names and are not silently replaced by derived `last_used_at` logic.
-
-### Task 6. Rebuild transactional pending-auth service
-
-- [ ] Implement `backend/internal/service/auth_pending_identity_service.go` and tests to own these flows:
- - create pending session from third-party callback
- - verify local email code
- - create new account from pending session with correct `signup_source`
- - bind pending identity to existing account after password/TOTP re-auth
- - apply configured provider defaults on the correct trigger only once
- - store provider nickname/avatar candidates and user opt-in replacement decisions independently
-- [ ] Implement callback completion so pending auth can finish through backend session completion or a one-time exchange code:
- - short TTL
- - one-time use
- - browser-session binding
- - `POST` redemption only
- - safe mixed-version bridge to legacy pending-token aliases during rollout
-- [ ] Keep pending session payload normalized:
- - provider identity fields live in typed columns/JSON structure
- - mutable local progression lives separately from immutable upstream claims
- - avoid the old branch’s mixed `metadata` and `upstream_identity_payload` ambiguity
-- [ ] Do not call plain email registration helpers from this flow. The old feature branch bug where pending third-party signup fell back to `RegisterWithVerification` must not reappear.
-
-### Task 7. Rebuild provider callback adapters
-
-- [ ] Refactor these handlers to thin adapters over the shared pending-auth service:
- - `backend/internal/handler/auth_linuxdo_oauth.go`
- - `backend/internal/handler/auth_oidc_oauth.go`
- - `backend/internal/handler/auth_wechat_oauth.go`
-- [ ] For OIDC:
- - require PKCE `S256`, `state`, and `nonce`
- - validate discovery issuer, `iss`, `aud`, optional `azp`, `exp`, and `nonce`
- - verify `userinfo.sub == id_token.sub` when UserInfo is used
- - persist canonical identity as `(issuer, sub)`
-- [ ] For WeChat:
- - MP flow in WeChat UA
- - Open/QR flow outside WeChat UA
- - website login uses authorization-code flow and persists channel/app binding
- - persist channel identity by `(channel, appid, openid)`
- - persist canonical identity by `unionid`
- - hard-fail when `unionid` is absent under the approved product policy
-- [ ] Replace callback URL fragment token delivery with backend session completion or one-time exchange code consumed by `frontend/src/stores/auth.ts`.
-
-### Task 8. Rebuild auth endpoints and profile binding endpoints
-
-- [ ] Implement `backend/internal/handler/auth_pending_identity_flow.go` for:
- - fetch pending session summary
- - submit verified email
- - choose create-new-account or bind-existing-account
- - submit nickname/avatar replacement choices
-- [ ] Make bind-existing-account and bind-current-user flows explicit:
- - no automatic linking on matching email
- - fresh password/TOTP proof is scoped to the intended target account only
- - no automatic metadata merge beyond explicitly selected nickname/avatar adoption
-- [ ] Update `backend/internal/handler/auth_handler.go` and `backend/internal/handler/user_handler.go` to expose:
- - current bindings summary
- - start-bind endpoints for LinuxDo/OIDC/WeChat
- - disconnect endpoints with safety checks
- - avatar upload/delete endpoints
-- [ ] Avatar handling requirements:
- - allow external URL
- - allow data URL upload
- - compress image payload to `<=100KB`
- - store compressed value in DB
- - deleting custom avatar must not implicitly resurrect stale provider avatar unless the user explicitly chooses provider avatar again
-
-### Task 9. Add admin visibility and sorting
-
-- [ ] Update `backend/internal/handler/admin/user_handler.go` and supporting query/service code so admin list supports:
- - `last_login_at`
- - `last_active_at`
- - sorting by both
- - binding/provider summary columns
-- [ ] Update `backend/internal/handler/admin/setting_handler.go` and setting service code for:
- - provider initial grant config
- - mandatory-email-on-third-party-signup config
- - payment source exclusivity config
- - advanced scheduler toggle
-
-### Task 10. Backend verification checkpoint
-
-- [ ] Run targeted backend tests:
- ```bash
- cd backend
- go test ./internal/repository -run 'TestUserProfileIdentity|TestAuthIdentityMigration'
- go test ./internal/service -run 'TestAuthIdentityFlow|TestPendingAuthIdentity|TestOpenAIAccountScheduler'
- go test ./internal/handler -run 'TestLinuxDo|TestOidc|TestWechat|TestPaymentWebhook'
- go test ./...
- ```
-- [ ] Commit checkpoint:
- ```bash
- git add backend
- git commit -m "feat: rebuild auth identity backend flows"
- ```
-
-## Phase 3: Frontend Third-Party Flow And Profile UX
-
-### Task 11. Rebuild callback flow UI around pending session decisions
-
-- [ ] Rebuild `frontend/src/components/auth/ThirdPartyAuthCallbackFlow.vue` so it:
- - loads pending-session summary from backend
- - shows provider nickname/avatar candidates
- - lets user independently choose nickname replacement and avatar replacement
- - handles create-new-account vs bind-existing-account
- - enforces verified local email before completion when required
- - handles “email already exists” by branching to bind-existing-account or change-email-and-create-new-account
-- [ ] Update:
- - `frontend/src/views/auth/LinuxDoCallbackView.vue`
- - `frontend/src/views/auth/OidcCallbackView.vue`
- - `frontend/src/views/auth/WechatCallbackView.vue`
- - `frontend/src/api/auth.ts`
- - `frontend/src/stores/auth.ts`
-- [ ] Replace any token-fragment bootstrap with backend session completion or one-time exchange code flow.
-- [ ] During rollout, keep temporary compatibility readers for legacy pending-token aliases behind a bounded bridge contract and explicit removal step.
-
-### Task 12. Rebuild profile account binding and avatar UX
-
-- [ ] Rebuild `frontend/src/components/user/profile/ProfileAccountBindingsCard.vue` to:
- - show linked LinuxDo/OIDC/WeChat providers
- - start bind/unbind flows
- - show provider avatars and nicknames as reference only
- - prevent unsafe disconnect when it would strand the account
-- [ ] Rebuild `frontend/src/components/user/profile/ProfileInfoCard.vue` and `frontend/src/views/user/ProfileView.vue` to:
- - support avatar URL entry
- - support data URL upload/compression preview
- - support avatar delete
- - clearly separate current profile nickname/avatar from provider-sourced suggested nickname/avatar
-
-### Task 13. Add frontend tests for rebuilt auth/profile flows
-
-- [ ] Add or update:
- - `frontend/src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts`
- - `frontend/src/components/auth/__tests__/LinuxDoCallbackView.spec.ts`
- - `frontend/src/components/auth/__tests__/WechatCallbackView.spec.ts`
- - `frontend/src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts`
- - `frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts`
-- [ ] Cover:
- - email-required branch
- - email-conflict branch
- - bind-existing-account with re-auth prompt
- - nickname replacement only
- - avatar replacement only
- - neither replacement
- - avatar delete after prior provider adoption
-
-## Phase 4: Payment Routing Rebuild
-
-### Task 14. Normalize payment routing backend
-
-- [ ] Rebuild `backend/internal/service/payment_config_service.go` to expose a stable method-routing contract:
- - frontend visible methods remain `alipay` and `wxpay`
- - admin chooses which provider instance serves each method
- - runtime validation guarantees only one active source per visible method
-- [ ] Add migration logic and tests to normalize historical provider-instance config:
- - `supported_types`
- - `payment_mode`
- - legacy aliases such as `wxpay_direct`
- - historical limit config
-- [ ] Rebuild `backend/internal/service/payment_order.go` and `backend/internal/service/payment_order_lifecycle.go` so order creation snapshots:
- - visible method
- - selected provider instance id
- - provider type
- - provider capability mode
- - verification-critical provider fields needed for later callback/query/refund validation
-- [ ] Rebuild `backend/internal/handler/payment_handler.go` for UX rules:
- - Alipay PC: QR page
- - Alipay mobile: direct jump
- - WeChat PC: QR page
- - WeChat H5 in WeChat: MP/JSAPI first, fallback to H5
- - WeChat H5 outside WeChat: H5 or “open in WeChat” prompt when unavailable
-- [ ] Never derive canonical return URL from `Referer`; use configured or signed internal callback targets only.
-- [ ] Implement `backend/internal/service/payment_resume_service.go` so WeChat in-app payment OAuth resume is server-backed rather than localStorage-backed:
- - create `oauth_required` resume context
- - persist amount/order_type/plan_id/visible method/redirect/state
- - redeem callback into same-origin internal resume target
- - expire and consume resume context safely
-
-### Task 15. Make fulfillment and reconciliation provider-instance-safe
-
-- [ ] Rebuild `backend/internal/handler/payment_webhook_handler.go` and `backend/internal/service/payment_fulfillment.go` so:
- - verification uses the order’s original provider instance
- - webhook processing is idempotent by provider event id and internal order id
- - missed webhook recovery uses server-side provider query, not frontend success return
-- [ ] For WeChat Pay specifically, enforce:
- - fixed HTTPS `notify_url` with no query params
- - no dependency on user login state
- - signature verification before decrypt
- - APIv3 decrypt before business parsing
- - comparison of `appid`, `mchid`, `out_trade_no`, `amount`, `currency`, and trade state against the order snapshot
-- [ ] Harden `frontend/src/views/user/PaymentResultView.vue` and `frontend/src/api/payment.ts` so result polling uses an authenticated order lookup or signed opaque token, not a raw public `out_trade_no` query.
-
-### Task 16. Rebuild payment frontend views
-
-- [ ] Rebuild `frontend/src/views/user/PaymentView.vue`, `frontend/src/views/user/PaymentQRCodeView.vue`, and `frontend/src/stores/payment.ts` so:
- - only two buttons are shown to user: `支付宝` and `微信支付`
- - frontend does not leak official-vs-EasyPay distinction
- - route-specific copy handles QR, jump, MP, H5 fallback correctly
-- [ ] Rebuild WeChat in-app payment resume UX around the server-backed resume context:
- - handle `oauth_required`
- - continue from same-origin resume target
- - avoid long-lived localStorage as the source of truth
-- [ ] Add or update:
- - `frontend/src/views/user/__tests__/PaymentView.spec.ts`
- - `frontend/src/views/user/__tests__/PaymentResultView.spec.ts`
- - backend webhook/payment routing tests
-
-### Task 17. Payment verification checkpoint
-
-- [ ] Run:
- ```bash
- cd backend
- go test ./internal/service -run 'TestPayment'
- go test ./internal/handler -run 'TestPayment'
- cd ../frontend
- pnpm test:run src/views/user/__tests__/PaymentView.spec.ts src/views/user/__tests__/PaymentResultView.spec.ts
- ```
-- [ ] Commit checkpoint:
- ```bash
- git add backend frontend
- git commit -m "feat: rebuild payment routing foundation"
- ```
-
-## Phase 5: Scheduler, Rollout, And Final Compatibility Pass
-
-### Task 18. Gate advanced scheduler behind explicit config
-
-- [ ] Update `backend/internal/service/openai_account_scheduler.go` and related admin setting surfaces so:
- - advanced scheduler remains compiled and testable
- - default runtime state is disabled
- - enablement is explicit through admin settings
- - legacy scheduling behavior remains default on upgrade
-- [ ] Add targeted coverage in `backend/internal/service/openai_account_scheduler_test.go`.
-
-### Task 19. Complete compatibility and rollout safety checks
-
-- [ ] Add migration/repository tests covering:
- - historical email-only user login after upgrade
- - historical LinuxDo user login after upgrade
- - historical synthetic-email LinuxDo user login after upgrade
- - historical synthetic-email WeChat user login after upgrade
- - historical synthetic-email OIDC user login after upgrade
- - historical WeChat `openid`-only rows are reported or explicitly remediated
- - no retroactive grant replay during migration
- - first-bind grant fires once only when enabled
- - email identity dual-write stays consistent
- - bind-existing-account requires password and TOTP where configured
- - mixed-version callback token bridge works during rollout and is removable afterward
- - historical payment config is normalized into visible-method routing without refund/query regression
-- [ ] Add deploy sequencing note to release docs or internal runbook:
- 1. deploy schema and backfill release.
- 2. inspect migration report for unmatched rows.
- 3. deploy backend identity/payment compatibility code with exchange bridge and legacy token aliases still enabled.
- 4. deploy frontend callback/profile/payment UI using session completion, exchange code, and server-backed WeChat payment resume.
- 5. remove legacy callback/token parsing after mixed-version window closes.
- 6. enable strict email-required signup or provider bind grants only after metrics are healthy.
-
-### Task 20. Final verification and handoff
-
-- [ ] Run final backend verification:
- ```bash
- cd backend
- go test ./...
- ```
-- [ ] Run targeted frontend verification:
- ```bash
- cd frontend
- pnpm test:run \
- src/components/auth/__tests__/ThirdPartyAuthCallbackFlow.spec.ts \
- src/components/auth/__tests__/LinuxDoCallbackView.spec.ts \
- src/components/auth/__tests__/WechatCallbackView.spec.ts \
- src/components/user/profile/__tests__/ProfileAccountBindingsCard.spec.ts \
- src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
- src/views/user/__tests__/PaymentView.spec.ts \
- src/views/user/__tests__/PaymentResultView.spec.ts
- ```
-- [ ] Run focused manual smoke checks:
- - email login with existing account
- - LinuxDo existing-account login after migration
- - WeChat synthetic-email account login after migration
- - OIDC synthetic-email account login after migration
- - third-party first login create-new-account path
- - third-party first login bind-existing-account path
- - first third-party bind with optional nickname/avatar replacement
- - PC Alipay QR
- - mobile Alipay jump
- - PC WeChat QR
- - WeChat H5 MP/JSAPI path
- - WeChat in-app OAuth resume path
- - non-WeChat H5 fallback path
-- [ ] Commit final checkpoint:
- ```bash
- git add docs backend frontend
- git commit -m "feat: rebuild auth identity and payment foundation"
- ```
-
-## Review Checklist
-
-- [ ] No flow still relies on provider email equality for account linking.
-- [ ] No flow still creates third-party users through plain email registration helpers.
-- [ ] No callback still returns first-party bearer tokens in URL fragments.
-- [ ] No callback completion path can be replayed as a bearer token substitute.
-- [ ] No payment result view trusts provider return page as authoritative fulfillment.
-- [ ] No webhook verification path selects provider credentials from “currently active config” instead of the order snapshot.
-- [ ] Existing email users, historical LinuxDo/WeChat/OIDC users, and `openid`-only WeChat remediation cases are covered by migration tests.
-- [ ] Avatar adoption and deletion semantics are explicit and reversible.
-- [ ] Grant timing is source-aware and one-time only.
diff --git a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md b/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
deleted file mode 100644
index 23823cf0..00000000
--- a/docs/superpowers/specs/2026-04-20-auth-identity-payment-foundation-design.md
+++ /dev/null
@@ -1,763 +0,0 @@
-# Auth Identity And Payment Foundation Design
-
-**Date:** 2026-04-20
-
-**Status:** Draft approved in conversation, written for implementation planning
-
-**Goal**
-
-Rebuild the `feat/auth-identity-foundation` intent on a clean branch from `main`, covering unified user identity, third-party login and binding, profile adoption, source-based signup defaults, unified payment routing and UX, admin configuration, compatibility with existing `main` data, and an opt-in OpenAI advanced scheduling switch.
-
-## Scope
-
-This design includes:
-
-- Email login and registration
-- Third-party login and binding for `LinuxDo`, `OIDC`, and `WeChat`
-- Unified identity storage for email and third-party identities
-- Pending auth sessions for callback-to-login/register/bind continuation
-- User-controlled nickname/avatar adoption during first relevant third-party flow
-- Profile binding management and avatar upload/delete
-- Source-based initial grants for balance, concurrency, and subscriptions
-- User management support for `last_login_at` and `last_active_at` sorting
-- Unified payment display methods (`alipay`, `wxpay`) mapped to a single active backend source each
-- Alipay and WeChat UX routing rules across PC, mobile, H5, and WeChat environments
-- Admin settings for auth providers, source defaults, payment sources, and OpenAI advanced scheduling
-- Incremental migration and compatibility for existing email users, existing LinuxDo users, historical LinuxDo/WeChat/OIDC synthetic-email users, and historical WeChat `openid`-only identity records
-
-This design does not treat unrelated upstream merges, docs churn, or license changes from the old branch as required scope.
-
-## Product Rules
-
-### Auth and identity
-
-- Existing email users remain valid and continue to log in with no manual action.
-- Existing LinuxDo, OIDC, and WeChat users represented by historical third-party or synthetic-email data must remain recoverable during migration.
-- Third-party first login behavior:
- - Existing bound identity: direct login
- - Missing identity: start first-login flow
-- Browser-based third-party authorization-code login always uses PKCE `S256`; this is not an admin-toggleable feature.
-- If `force_email_on_third_party_signup` is disabled, a first-login user may create an account without binding an email.
-- If `force_email_on_third_party_signup` is enabled, the user must provide an email.
-- If the provided and verified email already exists:
- - show that the email already exists
- - allow "verify and bind existing account"
- - allow "change email and continue registration"
- - do not allow bypassing the email requirement
-- Upstream provider email verification is not trusted as a local bound email.
-- Matching upstream email must never auto-link to an existing local account.
-- Linking to an existing local account is allowed only when:
- - the user explicitly chooses that target account
- - the target account passes fresh local re-authentication
- - required TOTP verification succeeds
-- New third-party bind initiated from profile must start from an already logged-in local account and preserve explicit bind intent end-to-end.
-- `redirect_to` may only represent a normalized same-origin internal route. It must never contain a third-party URL and must never be derived from `Referer`.
-- OIDC validation rules:
- - canonical identity key is `issuer + sub`
- - discovery issuer and ID token `iss` must match exactly
- - `userinfo.sub` must match ID token `sub` when UserInfo is used
- - upstream `email_verified` may improve UX copy but does not satisfy local email-binding requirements
-- WeChat login chooses channel by environment:
- - in WeChat environment: `mp`
- - outside WeChat: `open`
-- WeChat primary identity key is `unionid`.
-- If a WeChat login/bind flow cannot produce `unionid`, the flow fails and no fallback `openid` identity is created.
-- Historical WeChat records that only contain `openid` are treated as migration-remediation cases, not as a valid long-term canonical identity model.
-- WeChat website login uses authorization code flow, random `state`, and the provider channel/app binding must be persisted alongside the resolved identity.
-
-### Profile adoption
-
-- During the first relevant third-party flow, the user can independently decide:
- - replace current nickname or not
- - replace current avatar or not
-- This applies to first third-party registration and first third-party binding.
-- The decision is explicit user choice, not automatic replacement.
-
-### Source-based initial grants
-
-- Source-specific defaults exist for `email`, `linuxdo`, `oidc`, and `wechat`.
-- Each source defines:
- - default balance
- - default concurrency
- - default subscriptions
- - grant on signup
- - grant on first bind
-- Default behavior:
- - grant on signup: enabled
- - grant on first bind: disabled
-- First-bind grants are optional and controlled per source.
-- Grants must be idempotent.
-
-### Avatar management
-
-- Avatar supports:
- - external URL
- - image `data:` URL
-- `data:` URL images are compressed to at most `100KB` before persistence.
-- Avatar storage is database-backed.
-- Avatar delete is supported.
-
-### Payment UX and routing
-
-- Frontend shows only two display methods:
- - `alipay`
- - `wxpay`
-- Users never choose between official providers and EasyPay explicitly.
-- Backend allows only one active source per display method at a time.
-- Alipay UX:
- - PC: show QR code in page
- - mobile: jump to Alipay app/payment flow
-- WeChat UX:
- - PC: show QR code in page
- - non-WeChat H5: prefer H5 pay; if unavailable, tell the user to open in WeChat
- - WeChat environment: prefer MP/JSAPI pay; if unavailable, fall back to H5 pay
-- Payment success is confirmed by backend order state, webhook, and/or query, not only frontend return.
-- Frontend-visible labels remain `支付宝` and `微信支付`, while internal visible-method identifiers remain `alipay` and `wxpay`.
-- Public result pages must not verify order state by exposing raw `out_trade_no`; they use authenticated lookup or a signed opaque result token instead.
-- Payment callback or return URLs must be fixed same-origin internal targets. They must not be inferred from `Referer`.
-- WeChat payment webhook handling must use a fixed HTTPS `notify_url` with no query parameters and must not depend on user login state.
-
-### OpenAI advanced scheduling
-
-- OpenAI advanced scheduling is supported.
-- It is disabled by default.
-- Admin can enable it explicitly.
-
-## Architecture
-
-Keep `users` as the account owner table and move login identities, channel mappings, pending auth state, callback completion state, and first-bind grant idempotency into dedicated tables and services. Keep email login working while progressively introducing unified identity reads and writes.
-
-Payment uses a similar split between user-visible display methods and backend provider sources. Frontend works only with stable display methods while backend resolves to the currently active source and capability matrix, and stores enough order-time snapshot data to survive later provider-config changes.
-
-Compatibility is a first-class concern: migrations are additive, reads are compatibility-aware, and rollout must tolerate existing `main` data and short-lived frontend/backend version skew.
-
-## Data Model
-
-### `users`
-
-Preserve existing account ownership and local-login fields. Extend or use:
-
-- `email`
-- `password_hash`
-- `totp_enabled`
-- `signup_source`
-- `last_login_at`
-- `last_active_at`
-
-The `users` table remains the primary business subject for balance, concurrency, subscriptions, permissions, and profile.
-
-### `auth_identities`
-
-Represents all canonical login or bindable identities.
-
-Fields:
-
-- `user_id`
-- `provider_type`: `email`, `linuxdo`, `oidc`, `wechat`
-- `provider_key`
-- `provider_subject`
-- `verified_at`
-- `issuer`
-- `metadata`
-- timestamps
-
-Uniqueness:
-
-- `provider_type + provider_key + provider_subject` must be unique
-
-Rules:
-
-- email identity uses canonicalized local email
-- LinuxDo uses stable provider subject under the configured provider namespace
-- OIDC uses stable issuer + subject, with issuer namespace represented consistently through `provider_key` and `issuer`
-- WeChat uses `unionid` as canonical subject under the configured Open Platform namespace
-
-### `auth_identity_channels`
-
-Stores channel-specific subject mappings for an identity.
-
-Primary use:
-
-- WeChat `open` / `mp` / payment channel mapping
-
-Fields:
-
-- `identity_id`
-- `provider_type`
-- `provider_key`
-- `channel`
-- `channel_app_id`
-- `channel_subject`
-- `metadata`
-- timestamps
-
-Rules:
-
-- canonical WeChat identity still keys on `unionid`
-- `openid` values live here as channel mappings
-
-### `pending_auth_sessions`
-
-Stores callback state between third-party callback and final account action.
-
-Fields:
-
-- `intent`
-- `provider_type`
-- `provider_key`
-- `provider_subject`
-- `target_user_id`
-- `redirect_to`
-- `resolved_email`
-- `registration_password_hash`
-- `upstream_identity_claims`
-- `local_flow_state`
-- `browser_session_key`
-- `completion_code_hash`
-- `completion_code_expires_at`
-- `email_verified_at`
-- `password_verified_at`
-- `totp_verified_at`
-- `expires_at`
-- `consumed_at`
-- timestamps
-
-Responsibilities:
-
-- continue provider callback into register/login/bind flows
-- persist nickname/avatar suggestions
-- persist explicit adoption decisions
-- survive navigation between auth pages
-- support mixed-version rollout through short-lived legacy token aliases when required
-
-Security rules:
-
-- callback completion uses backend session completion or a one-time exchange code
-- exchange codes are short-lived, one-time, bound to browser session and pending session, and redeemed via `POST`
-- exchange codes must not behave as bearer tokens and must not be logged, stored in URL fragments, or reused after redemption
-- `local_flow_state` stores mutable local progression only; immutable upstream claims remain in `upstream_identity_claims`
-
-### `identity_adoption_decisions`
-
-Persists user adoption preference collected during a pending-auth flow and resolved onto the bound identity.
-
-Fields:
-
-- `pending_auth_session_id`
-- `identity_id`
-- `adopt_display_name`
-- `adopt_avatar`
-- `decided_at`
-- timestamps
-
-Rules:
-
-- one adoption-decision row exists per pending session
-- `identity_id` is filled once final account creation or bind succeeds
-
-### `user_avatars`
-
-Stores the currently effective custom avatar.
-
-Fields:
-
-- `user_id`
-- `storage_provider`
-- `storage_key`
-- `url`
-- `content_type`
-- `byte_size`
-- `sha256`
-- timestamps
-
-Rules:
-
-- supports URL-backed and inline data-backed representations
-- hard maximum payload size is `100KB`
-
-### `user_provider_default_grants`
-
-Stores idempotency state for source grants.
-
-Fields:
-
-- `user_id`
-- `provider_type`
-- `granted_at`
-- timestamps
-
-Responsibilities:
-
-- prevent duplicate first-bind grants
-- allow signup grants and first-bind grants to be reasoned about independently
-
-## Identity Keys And Canonicalization
-
-- Email canonical key: `lower(trim(email))`
-- LinuxDo canonical key: provider subject from LinuxDo
-- OIDC canonical key: `issuer + sub`
-- WeChat canonical key: `unionid`
-
-WeChat-specific rule:
-
-- `openid` never becomes the primary stored identity key
-- if only `openid` is available, login/bind fails with a configuration/identity error
-- historical `openid`-only records must be reported and either remediated during migration or explicitly blocked from silent auto-upgrade
-
-## Core Flows
-
-### Email register/login
-
-- Existing email auth flow remains
-- On email registration, create canonical `email` identity
-- Apply `email` source signup defaults
-
-### Third-party login with existing identity
-
-- Resolve canonical identity
-- Login mapped `user`
-- Update `last_login_at`
-- Do not issue signup or first-bind grants again
-
-### Third-party first login with no identity
-
-- Create `pending_auth_session`
-- Frontend callback flow decides next action
-- Pending session creation stores immutable upstream claims separately from mutable local progress fields
-
-Branches:
-
-- no forced email binding:
- - user can create account directly
-- forced email binding:
- - user must supply local email
-
-If supplied local email already exists:
-
-- tell the user the email already exists
-- allow verify-and-bind-existing-account
-- allow changing email to continue registration
-
-On new account creation:
-
-- create `users` row
-- create canonical third-party identity
-- create or update canonical email identity when local email binding succeeds
-- apply source signup grants
-- apply adoption choices if selected
-
-### Bind third-party identity to current logged-in user
-
-- current user starts bind flow
-- callback resolves to `bind_current_user`
-- bind intent is tied to the initiating local user session and cannot be re-targeted by email match
-- bind canonical identity to current user
-- if configured and first bind for that provider, apply first-bind grants
-- present nickname/avatar replacement choice
-
-### Bind existing account during first-login flow
-
-- user explicitly selects bind-existing-account
-- verify password for existing account
-- if account requires TOTP, verify TOTP
-- bind canonical identity to target account
-- optionally apply first-bind grants
-- present nickname/avatar replacement choice
-- no automatic profile or metadata merge occurs beyond explicitly selected nickname/avatar replacement
-
-### Callback completion and exchange flow
-
-- third-party callback never returns first-party bearer tokens in URL fragments
-- callback completion uses either:
- - backend session completion tied to the initiating browser session
- - one-time opaque exchange code redeemed by `POST`
-- mixed-version rollout may temporarily emit legacy pending token aliases in addition to the new completion path
-- legacy alias support is transitional and bounded to rollout windows only
-
-### WeChat login and channel mapping
-
-- environment chooses `mp` or `open`
-- website login uses authorization-code flow with provider-configured app/channel binding
-- callback must resolve to `unionid`
-- channel `openid` is optionally recorded in `auth_identity_channels`
-- failure to obtain `unionid` aborts flow
-
-### Avatar upload and delete
-
-- URL avatar: validate and persist reference
-- data URL avatar:
- - decode
- - validate image type
- - compress to `<=100KB`
- - persist database-backed inline representation
-- delete removes current custom avatar entry
-
-## Payment Routing Model
-
-### User-visible methods
-
-- `alipay`
-- `wxpay`
-
-### Backend source abstraction
-
-Each display method maps to exactly one active configured backend source:
-
-- `official_alipay`
-- `easypay_alipay`
-- `official_wechat`
-- `easypay_wechat`
-
-Frontend submits display method only. Backend resolves display method to active source and capability set.
-
-### Legacy payment-config normalization
-
-- existing provider-instance `supported_types`, legacy aliases such as `wxpay_direct`, and per-type limit structures are migrated into the visible-method model
-- migration preserves historical payment capability and refund semantics
-- the system keeps one normalized visible-method mapping per provider instance for rollout and audit
-
-### Alipay routing
-
-- PC: create QR-oriented result and show QR in page
-- mobile: create jump/redirect-oriented result
-
-### WeChat routing
-
-- PC: QR result
-- non-WeChat H5:
- - prefer H5 pay
- - if unavailable, show "open in WeChat" requirement
-- WeChat environment:
- - prefer MP/JSAPI
- - if unavailable, fall back to H5 pay
-
-### WeChat payment OAuth recovery
-
-- if WeChat in-app payment requires `openid` and the current request does not already hold it, backend returns an `oauth_required` response instead of guessing
-- backend creates a server-backed payment-resume context containing:
- - target visible method
- - amount/order type/plan context
- - redirect target
- - anti-replay state
-- backend redirects through a dedicated WeChat payment OAuth start endpoint
-- callback exchanges the provider code server-side, stores `openid` in the payment-resume context, and returns a same-origin internal resume target
-- frontend resumes the original order flow through the resume context instead of trusting raw callback query state or long-lived local storage
-
-### Payment completion
-
-- frontend return restores context and UI state
-- backend order state remains source of truth
-- webhook and/or order query remain authoritative for fulfillment
-- order fulfillment validates webhook or query payload against order-time snapshot data including provider instance, merchant identifiers, amount, currency, and provider order references
-- result pages use authenticated lookup or signed opaque result tokens, never raw public `out_trade_no`
-
-## Admin Configuration Model
-
-### Auth provider settings
-
-- email registration and verification settings
-- force email on third-party signup
-- LinuxDo client settings
-- OIDC issuer/client settings and provider display name
-- WeChat `open` / `mp` capability indicators derived from environment-backed configuration, surfaced to the frontend/admin read models as effective availability rather than full in-panel credential editing
-
-### Source default settings
-
-Per source (`email`, `linuxdo`, `oidc`, `wechat`):
-
-- default balance
-- default concurrency
-- default subscriptions
-- grant on signup
-- grant on first bind
-
-### Payment settings
-
-- active source for `alipay`
-- active source for `wxpay`
-- source-specific credentials and enablement
-- effective WeChat payment capabilities may differ by enabled provider instances and selected visible-method source:
- - QR available
- - H5 available
- - MP/JSAPI available
-
-### Scheduling settings
-
-- OpenAI advanced scheduling enabled/disabled
-- default disabled
-
-## Compatibility And Rollout
-
-Compatibility is mandatory, especially for:
-
-- existing email users
-- existing LinuxDo users
-- historical LinuxDo synthetic-email accounts
-- historical WeChat synthetic-email accounts
-- historical OIDC synthetic-email accounts
-- historical WeChat `openid`-only records created by older branches
-
-### Additive migrations
-
-- preserve existing `users` data and behavior
-- add identity and pending-session tables
-- avoid destructive schema swaps
-
-### Migration backfill
-
-- backfill canonical `email` identities for valid existing email users
-- backfill canonical `linuxdo` identities during migration for historical synthetic-email LinuxDo users
-- backfill canonical `wechat` and `oidc` identities when historical synthetic-email or `user_external_identities` data allows deterministic reconstruction
-- emit migration reports for historical WeChat `openid`-only records that cannot be safely promoted to canonical `unionid`
-- backfill must be idempotent and repeatable
-
-### Compatibility reads
-
-During rollout:
-
-- read new identity model first
-- where necessary, retain compatibility logic for existing email and historical LinuxDo/WeChat/OIDC synthetic-email recognition
-
-### Grant idempotency
-
-- migration backfill must not trigger signup or first-bind grants
-- first-bind grants must use explicit idempotency tracking
-
-### API compatibility
-
-Retain transitional support for legacy/new request and response shapes where needed, including:
-
-- `pending_auth_token`
-- `pending_oauth_token`
-- old callback parsing expectations
-- historical profile field mappings
-- legacy callback fragment readers during the bounded rollout window
-
-### Settings and payment compatibility
-
-- preserve existing payment configs and order semantics from `main`
-- add new settings incrementally
-- avoid rewriting the entire settings schema in one cutover
-- preserve legacy provider-instance capabilities by explicitly mapping historical `supported_types`, `payment_mode`, and limit config into normalized visible-method routing
-
-### Rolling upgrade tolerance
-
-- do not assume simultaneous frontend/backend deployment
-- new backend must tolerate short-lived older frontend request shapes
-- rollout must define the deployment order and removal point for legacy callback token parsing and legacy payment resume parsing
-
-## Testing Strategy
-
-### Repository tests
-
-- identity upsert and lookup
-- WeChat channel mapping
-- pending auth session persistence
-- source grant idempotency
-- avatar persistence and delete
-- migration backfill behavior
-
-### Service tests
-
-- direct login by existing identity
-- first third-party signup
-- forced email flow
-- existing-email bind-existing-account flow
-- first-bind grant on/off
-- nickname/avatar adoption choices
-- WeChat `unionid` required behavior
-- payment routing resolution
-
-### Handler and route tests
-
-- LinuxDo/OIDC/WeChat callback handling
-- bind-existing
-- bind-current-user
-- create-account
-- TOTP continuation
-- payment create and recovery
-
-### Frontend tests
-
-- third-party callback flow state machine
-- register/login continuation
-- profile bindings card
-- avatar interactions
-- payment page routing behavior
-- admin settings UI
-
-### Compatibility tests
-
-- existing email users
-- historical LinuxDo synthetic-email users
-- historical WeChat synthetic-email users
-- historical OIDC synthetic-email users
-- historical WeChat `openid`-only records reported or remediated correctly
-- historical payment config
-- legacy auth payload field names
-- historical payment result handling
-- mixed-version callback token bridge behavior
-
-## Implementation Phases
-
-1. Add schema, migrations, compatibility backfill, and repository support
-2. Implement unified identity services and pending auth session flows
-3. Integrate profile binding, avatar, and adoption decision flows
-4. Add per-source default grants and admin config surfaces
-5. Rebuild payment routing abstraction and frontend payment UX
-6. Add user-management sorting and OpenAI advanced scheduling switch
-7. Run compatibility, rollout, and regression hardening
-
-## External Constraints And Best Practices
-
-Implementation must follow current primary-source guidance:
-
-- OAuth 2.0 Security BCP (RFC 9700): strict redirect handling, state protection, mix-up resistant design
-- PKCE (RFC 7636): require `S256` on browser authorization-code flows
-- OpenID Connect Core: stable issuer/subject handling for OIDC identities
-- Account linking best practice: require explicit user confirmation or re-authentication before linking to existing accounts
-- WeChat UnionID and website-login guidance: treat `unionid` as canonical cross-channel subject and persist channel/app binding with website login responses
-- WeChat Pay webhook guidance: verify signatures, decrypt payloads, and confirm merchant/order/amount fields against order-time state before fulfillment
-- Payment success-page guidance: custom success pages are informational and must not be the only fulfillment trigger
-
-References:
-
-- RFC 9700 (OAuth 2.0 Security Best Current Practice): https://www.rfc-editor.org/rfc/rfc9700
-- RFC 7636 (PKCE): https://www.rfc-editor.org/rfc/rfc7636
-- OpenID Connect Core 1.0: https://openid.net/specs/openid-connect-core-1_0.html
-- Auth0 account linking guidance: (link to be added)
-- WeChat UnionID guidance: (link to be added)
-- WeChat website login guidance: (link to be added)
-- WeChat Pay callback/signature guidance: (link to be added)
-- Stripe Checkout fulfillment guidance: (link to be added)
-
-## Audit Synthesis
-
-The clean rebuild direction is not to copy either existing branch directly.
-
-- `feat/auth-identity-foundation` has the better long-term model:
- - unified auth identities
- - pending auth sessions
- - identity adoption decisions
- - provider-scoped default grants
- - payment display-method abstraction
- - OpenAI advanced scheduler layering
-- `personal-dev-branch` has the better real-world closure:
- - LinuxDo and WeChat callback flows are more operationally complete
- - profile binding and avatar UX is more complete
- - historical synthetic-email users across multiple providers are recognized and recovered in live flows
- - WeChat payment OAuth and recovery behavior is more complete
-- Primary-source guidance supplies hard constraints for OAuth/OIDC, account linking, WeChat identity handling, and payment completion semantics.
-
-The final rebuild must therefore:
-
-- keep the `feat/auth-identity-foundation` data model direction
-- absorb the strongest business-flow behavior from `personal-dev-branch`
-- reject transitional or half-finished behavior from both branches
-- treat compatibility and rollout as first-class implementation scope
-
-## Keep / Adapt / Drop
-
-### Keep
-
-Keep these architectural choices essentially intact:
-
-- `auth_identities`, `auth_identity_channels`, `pending_auth_sessions`, `identity_adoption_decisions`
-- per-provider default grants with one-time grant tracking
-- WeChat canonical identity plus channel mapping model
-- pending-auth verification gates before final bind
-- payment visible-method abstraction (`alipay`, `wxpay`) decoupled from backend provider source
-- OpenAI advanced scheduler layering and test-backed behavior
-
-Keep these operational flow ideas from `personal-dev-branch`:
-
-- LinuxDo pending identity callback flow
-- WeChat pending identity callback flow
-- profile bindings UX and “cannot disconnect last usable login method” rule
-- separate WeChat login OAuth and WeChat payment OAuth entry points
-- historical synthetic-email recognition logic as a migration bridge
-- explicit WeChat payment OAuth recovery protocol as a product requirement, but reimplemented with server-backed resume state
-
-### Adapt
-
-These areas must be reimplemented with the same intent but stricter boundaries:
-
-- third-party account creation from pending-auth state must be transactional and must not register a plain local user before identity finalization succeeds
-- email identity lifecycle must become real dual-write state, not just one migration-time backfill
-- `signup_source` must be backfilled more accurately for known historical third-party users
-- WeChat payment recovery state must move from frontend-only storage to server-backed continuation state
-- avatar adoption fetches must be security-hardened and failure-visible
-- pending-auth payload modeling must clearly separate immutable upstream payload from mutable local metadata
-- callback completion must use a real exchange/session model instead of fragment-delivered bearer tokens
-- profile binding/avatar DTOs must be simplified to one authoritative backend contract instead of sprawling frontend fallback parsing
-- admin settings should preserve capability while reducing duplicated or transitional config branches
-
-### Drop
-
-Drop these as long-term design choices:
-
-- `user_external_identities` as the primary long-term identity model
-- synthetic email as a long-term canonical identity representation
-- OIDC as a side-path that does not participate in the same identity foundation as LinuxDo and WeChat
-- frontend multi-endpoint probing and broad compatibility parsing once the clean branch becomes the sole supported contract
-- unrelated branch noise such as generated-file churn, locale-only churn, or upstream merge residue as design inputs
-
-## Audit-Driven Hard Constraints
-
-The audit and source review establish these hard constraints:
-
-### Auth
-
-- all browser authorization-code providers use PKCE `S256` and do not expose an admin-off switch
-- callback handling uses strict `redirect_uri` discipline and state validation
-- OIDC identity key is `issuer + sub`
-- existing-account linking after email conflict must require explicit user action plus local-account verification
-- WeChat canonical identity key is `unionid`; `openid` is channel-scoped only
-
-### Compatibility
-
-- existing email users must continue to work with no manual intervention
-- existing LinuxDo users must not split into duplicate accounts
-- historical LinuxDo/WeChat/OIDC synthetic-email users must be backfilled into canonical identities during migration when deterministic recovery is possible
-- historical WeChat `openid`-only records must be surfaced through migration reporting and explicit remediation rules
-- migration backfills must not trigger signup or first-bind grants
-- legacy `pending_auth_token` and `pending_oauth_token` contracts must remain accepted during rollout
-- legacy auth/public setting aliases needed by older frontend builds must remain available during rollout
-- existing payment configs and historical order semantics must remain valid
-
-### Payment
-
-- frontend return pages do not determine final payment success
-- backend order state, webhook processing, and/or provider status query remain authoritative
-- each visible method (`alipay`, `wxpay`) may have only one active backend source at a time
-- public result pages must not expose raw `out_trade_no` lookup
-- WeChat Pay callback handling must verify signature, decrypt payload, and compare order fields against order-time snapshot data
-
-## Known Risks To Eliminate In Implementation
-
-These are specifically observed problems in the existing branches that the clean rebuild must eliminate:
-
-- third-party forced-email account creation currently bypasses the provider-aware account creation path and can leave orphan local accounts if bind finalization fails
-- post-migration email accounts are not fully dual-written into `auth_identities`
-- avatar adoption currently risks silent failure and insecure outbound fetch behavior
-- pending-auth payload responsibilities are internally inconsistent
-- OIDC parity is incomplete in `personal-dev-branch`; it must become a first-class provider in the unified identity model
-- WeChat union/open/channel identity handling is conceptually correct in the feature branch but still partially transitional across the codebase
-- WeChat payment recovery in `personal-dev-branch` is frontend-local and not robust across tabs or concurrent attempts
-- the existing pending-auth migration update is too destructive to reuse unchanged in a safer rollout
-- historical provider provenance should not be permanently flattened to `signup_source = email`
-- design/plan drift can reintroduce ambiguous identity uniqueness or ambiguous adoption-decision ownership if not aligned before implementation
-
-## Rollout Gates
-
-The rebuild is not ready for rollout until all of these are satisfied:
-
-1. Identity schema and migration chain are linearized and production-safe.
-2. Email identity backfill is complete and idempotent.
-3. Historical LinuxDo/WeChat/OIDC synthetic-email backfill to canonical identity is complete where deterministic, and non-recoverable rows are reported.
-4. Historical WeChat `openid`-only rows are either remediated or explicitly blocked with operator-visible reporting.
-5. `signup_source` backfill is accurate for known historical provider-created users.
-6. Dual token acceptance, exchange bridge behavior, and required legacy field aliases are present for the bounded rollout window.
-7. Existing payment configs are normalized and verified against current frontend-visible capabilities.
-8. New frontend flows are verified against mixed-version backend compatibility windows.
-9. Duplicate-account creation, first-bind grants, and payment route selection have regression coverage.
--
GitLab
From ed01c599161acc67a18ef19c73daeb9fbe1243ab Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:54:53 +0800
Subject: [PATCH 148/261] feat: track authenticated user activity
---
.../server/middleware/admin_auth_test.go | 29 ++++++++
.../internal/server/middleware/jwt_auth.go | 16 ++++-
.../server/middleware/jwt_auth_test.go | 59 ++++++++++++++++
.../service/admin_service_apikey_test.go | 3 +
.../service/admin_service_delete_test.go | 12 ++++
.../admin_service_email_identity_sync_test.go | 4 ++
backend/internal/service/user_service.go | 68 +++++++++++++++++++
backend/internal/service/user_service_test.go | 65 ++++++++++++++----
frontend/src/views/admin/UsersView.vue | 17 ++---
.../views/admin/__tests__/UsersView.spec.ts | 7 +-
10 files changed, 254 insertions(+), 26 deletions(-)
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index ed2578c8..cc5bead3 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
+func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -189,6 +214,10 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go
index 4aceb355..48cb9004 100644
--- a/backend/internal/server/middleware/jwt_auth.go
+++ b/backend/internal/server/middleware/jwt_auth.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "context"
"errors"
"strings"
@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
- return JWTAuthMiddleware(jwtAuth(authService, userService))
+ return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
+}
+
+type jwtUserReader interface {
+ GetByID(ctx context.Context, id int64) (*service.User, error)
+}
+
+type userActivityToucher interface {
+ TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
-func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
+func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
+ if activityToucher != nil {
+ activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
+ }
c.Next()
}
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index c483a51e..84fd6967 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
+func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
+ return nil
+}
+
+type recordingActivityToucher struct {
+ userIDs []int64
+}
+
+func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
+ if user == nil {
+ return
+ }
+ r.userIDs = append(r.userIDs, user.ID)
+}
+
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
+func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
+ user := &service.User{
+ ID: 1,
+ Email: "test@example.com",
+ Role: "user",
+ Status: service.StatusActive,
+ Concurrency: 5,
+ TokenVersion: 1,
+ }
+
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
+ cfg.JWT.AccessTokenExpireMinutes = 60
+
+ userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ toucher := &recordingActivityToucher{}
+
+ r := gin.New()
+ r.Use(jwtAuth(authSvc, userSvc, toucher))
+ r.GET("/protected", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ token, err := authSvc.GenerateToken(user)
+ require.NoError(t, err)
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/protected", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ r.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusOK, w.Code)
+ require.Equal(t, []int64{1}, toucher.userIDs)
+}
+
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index e2eae0b4..aab35d25 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -88,6 +88,9 @@ func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, [
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
panic("unexpected")
}
+func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index ac1d8ee7..126faad9 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -107,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
index d6a7af9a..eaf4e84b 100644
--- a/backend/internal/service/admin_service_email_identity_sync_test.go
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -97,6 +97,10 @@ func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*ti
return nil, nil
}
+func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index e6053984..c6bf14c2 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -19,10 +19,13 @@ import (
"log/slog"
"net/url"
"sort"
+ "strconv"
"strings"
+ "sync"
"time"
xdraw "golang.org/x/image/draw"
+ "golang.org/x/sync/singleflight"
)
var (
@@ -47,6 +50,8 @@ const (
notifyCodeUserRateWindow = 10 * time.Minute
defaultUserIdentityRedirect = "/settings/profile"
+ userLastActiveMinTouch = 10 * time.Minute
+ userLastActiveFailBackoff = 30 * time.Second
)
var (
@@ -82,6 +87,7 @@ type UserRepository interface {
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error)
GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error)
+ UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -192,6 +198,8 @@ type UserService struct {
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
+ lastActiveTouchL1 sync.Map
+ lastActiveTouchSF singleflight.Group
}
// NewUserService 创建用户服务实例
@@ -788,6 +796,66 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
return user, nil
}
+// TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。
+// 该操作为尽力而为,不应中断正常请求。
+func (s *UserService) TouchLastActive(ctx context.Context, userID int64) {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ slog.Debug("skip touch user last active after load failure", "user_id", userID, "error", err)
+ return
+ }
+ s.TouchLastActiveForUser(ctx, user)
+}
+
+// TouchLastActiveForUser 使用已加载的用户信息更新 last_active_at,避免重复读取数据库。
+func (s *UserService) TouchLastActiveForUser(ctx context.Context, user *User) {
+ if s == nil || s.userRepo == nil || user == nil || user.ID <= 0 {
+ return
+ }
+
+ now := time.Now()
+ if userLastActiveFresh(user.LastActiveAt, now) {
+ return
+ }
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
+ return
+ }
+ }
+
+ _, err, _ := s.lastActiveTouchSF.Do(strconv.FormatInt(user.ID, 10), func() (any, error) {
+ latest := time.Now()
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
+ return nil, nil
+ }
+ }
+ if userLastActiveFresh(user.LastActiveAt, latest) {
+ return nil, nil
+ }
+ if err := s.userRepo.UpdateUserLastActiveAt(ctx, user.ID, latest); err != nil {
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveFailBackoff))
+ return nil, fmt.Errorf("touch user last active: %w", err)
+ }
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveMinTouch))
+ return nil, nil
+ })
+ if err != nil {
+ slog.Warn("touch user last active failed", "user_id", user.ID, "error", err)
+ }
+}
+
+func userLastActiveFresh(lastActiveAt *time.Time, now time.Time) bool {
+ if lastActiveAt == nil {
+ return false
+ }
+ return now.Before(lastActiveAt.Add(userLastActiveMinTouch))
+}
+
func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
return nil
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index d771cb75..2c11f8ec 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -23,18 +23,21 @@ import (
// --- mock: UserRepository ---
type mockUserRepo struct {
- updateBalanceErr error
- updateBalanceFn func(ctx context.Context, id int64, amount float64) error
- getByIDUser *User
- getByIDErr error
- updateFn func(ctx context.Context, user *User) error
- updateCalls int
- upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
- upsertAvatarArgs []UpsertUserAvatarInput
- deleteAvatarFn func(ctx context.Context, userID int64) error
- deleteAvatarIDs []int64
- getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
- txCalls int
+ updateBalanceErr error
+ updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ getByIDUser *User
+ getByIDErr error
+ updateLastActiveErr error
+ updateLastActiveUserIDs []int64
+ updateLastActiveAt []time.Time
+ updateFn func(ctx context.Context, user *User) error
+ updateCalls int
+ upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarFn func(ctx context.Context, userID int64) error
+ deleteAvatarIDs []int64
+ getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
+ txCalls int
}
type mockUserRepoTxKey struct{}
@@ -144,6 +147,11 @@ func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float
}
return m.updateBalanceErr
}
+func (m *mockUserRepo) UpdateUserLastActiveAt(_ context.Context, userID int64, activeAt time.Time) error {
+ m.updateLastActiveUserIDs = append(m.updateLastActiveUserIDs, userID)
+ m.updateLastActiveAt = append(m.updateLastActiveAt, activeAt)
+ return m.updateLastActiveErr
+}
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
@@ -288,6 +296,39 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
+func TestTouchLastActive_UpdatesWhenStale(t *testing.T) {
+ stale := time.Now().Add(-11 * time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &stale,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Equal(t, []int64{42}, repo.updateLastActiveUserIDs)
+ require.Len(t, repo.updateLastActiveAt, 1)
+ require.WithinDuration(t, time.Now(), repo.updateLastActiveAt[0], 2*time.Second)
+}
+
+func TestTouchLastActive_SkipsWhenRecent(t *testing.T) {
+ recent := time.Now().Add(-time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &recent,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Empty(t, repo.updateLastActiveUserIDs)
+ require.Empty(t, repo.updateLastActiveAt)
+}
+
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue
index 07c9d437..93cfdbbe 100644
--- a/frontend/src/views/admin/UsersView.vue
+++ b/frontend/src/views/admin/UsersView.vue
@@ -455,12 +455,6 @@
{{ formatDateTime(value) }}
-
-
- {{ value ? formatDateTime(value) : '-' }}
-
-
-
{{ value ? formatDateTime(value) : '-' }}
@@ -718,7 +712,6 @@ const allColumns = computed(() => [
{ key: 'usage', label: t('admin.users.columns.usage'), sortable: false },
{ key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true },
{ key: 'status', label: t('admin.users.columns.status'), sortable: true },
- { key: 'last_login_at', label: t('admin.users.columns.lastLogin'), sortable: true },
{ key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true },
{ key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true },
{ key: 'created_at', label: t('admin.users.columns.created'), sortable: true },
@@ -735,7 +728,9 @@ const toggleableColumns = computed(() =>
const hiddenColumns = reactive>(new Set())
// Default hidden columns (columns hidden by default on first load)
-const DEFAULT_HIDDEN_COLUMNS = ['notes', 'groups', 'subscriptions', 'usage', 'concurrency', 'last_login_at', 'last_active_at']
+const DEFAULT_HIDDEN_COLUMNS = ['notes', 'groups', 'subscriptions', 'usage', 'concurrency']
+const REMOVED_COLUMNS = new Set(['last_login_at'])
+const FORCED_VISIBLE_COLUMNS = new Set(['last_active_at'])
// localStorage key for column settings
const HIDDEN_COLUMNS_KEY = 'user-hidden-columns'
@@ -746,7 +741,9 @@ const loadSavedColumns = () => {
const saved = localStorage.getItem(HIDDEN_COLUMNS_KEY)
if (saved) {
const parsed = JSON.parse(saved) as string[]
- parsed.forEach(key => hiddenColumns.add(key))
+ parsed
+ .filter(key => !REMOVED_COLUMNS.has(key) && !FORCED_VISIBLE_COLUMNS.has(key))
+ .forEach(key => hiddenColumns.add(key))
} else {
// Use default hidden columns on first load
DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
@@ -808,7 +805,7 @@ const searchQuery = ref('')
const USER_SORT_STORAGE_KEY = 'admin-users-table-sort'
const loadInitialSortState = (): { sort_by: string; sort_order: 'asc' | 'desc' } => {
const fallback = { sort_by: 'created_at', sort_order: 'desc' as 'asc' | 'desc' }
- const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_login_at', 'last_used_at', 'last_active_at', 'created_at'])
+ const sortable = new Set(['email', 'id', 'username', 'role', 'balance', 'concurrency', 'status', 'last_used_at', 'last_active_at', 'created_at'])
try {
const raw = localStorage.getItem(USER_SORT_STORAGE_KEY)
if (!raw) return fallback
diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts
index 1ea67b63..d9076777 100644
--- a/frontend/src/views/admin/__tests__/UsersView.spec.ts
+++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts
@@ -113,7 +113,7 @@ describe('admin UsersView', () => {
getBatchUserAttributes.mockResolvedValue({ values: {} })
})
- it('shows last_used_at column and requests last_used_at sort', async () => {
+ it('shows active and used activity columns, hides last_login_at, and requests last_used_at sort', async () => {
const wrapper = mount(UsersView, {
global: {
stubs: {
@@ -144,7 +144,10 @@ describe('admin UsersView', () => {
await flushPromises()
- expect(wrapper.get('[data-test="columns"]').text()).toContain('last_used_at')
+ const columns = wrapper.get('[data-test="columns"]').text()
+ expect(columns).toContain('last_used_at')
+ expect(columns).toContain('last_active_at')
+ expect(columns).not.toContain('last_login_at')
await wrapper.get('[data-test="sort-last-used"]').trigger('click')
await flushPromises()
--
GitLab
From 49258dd3f6dfaab31589c797c033620c498a21d3 Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 14:55:07 +0800
Subject: [PATCH 149/261] fix: preserve scheduler transport compatibility
defaults
---
backend/internal/service/openai_account_scheduler.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 38b92b47..5fda3abd 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -767,6 +767,9 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
if s == nil || s.service == nil {
return false
}
--
GitLab
From 78f691d2de24d0d13ce68922e120c8119ea32856 Mon Sep 17 00:00:00 2001
From: shaw
Date: Tue, 21 Apr 2026 12:13:45 +0800
Subject: [PATCH 150/261] chore: update sponsors
---
README.md | 5 +++++
README_CN.md | 5 +++++
README_JA.md | 5 +++++
assets/partners/logos/bestproxy.png | Bin 0 -> 9716 bytes
4 files changed, 15 insertions(+)
create mode 100644 assets/partners/logos/bestproxy.png
diff --git a/README.md b/README.md
index bee2e8c3..3e609d65 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups , users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)
+
+
+Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home-network IPs with fingerprint isolation, it achieves per-link environment isolation and reduces the likelihood of triggering account-association risk controls.
+
+
## Ecosystem
diff --git a/README_CN.md b/README_CN.md
index 892eee61..add32a17 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -95,6 +95,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
+
+
+感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP的服务商,支持一号一IP独享;通过结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。
+
+
## 生态项目
diff --git a/README_JA.md b/README_JA.md
index 6f0fc900..ccd595b9 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -95,6 +95,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
+
+
+Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。
+
+
## エコシステム
diff --git a/assets/partners/logos/bestproxy.png b/assets/partners/logos/bestproxy.png
new file mode 100644
index 0000000000000000000000000000000000000000..87c586705020a09a07484d0b0cb9f3bff79a219d
GIT binary patch
literal 9716
zcmd6NXE>bQ*Y9m2j6Mko5n_lMWz-mLFuEv_kRUpvL}&Eg1<~^8HADzOw1^&^=tM-c
z=!EFKo?DXh{GaDN?}u}p^WmJiu5s_ZXYIAu-h2I4yFI(FB1b|7?Rn$8(~5ZK9UIRXG_%f%Nbj-B)-09=-}P}jt2Dk}*a+gfuOn%Ej)xLmF6
zKxzOG6?e5WG`7TG;YJuU3mY+(^_m72xP^%r%VPm$Ze=?ujJbupy94H-yNbH8yQQ&^
z35z(2P}EfzWMGZK8p2(zt!x~HUBy^_@f8N!7sCh^kjTNrRQQ3k>>m=aC&prq#o7rY
z5H2n*TrPZEwhm?p9w8wi1UD~&mzNWy;B<7e!5X@9+Bn_@IS|5s@sP$i8ar6nVJ&QJ
z;1@g%jclE;Vk|6#@IS02u~-KSqrbV^IC2^O8gaUq!Ub*`VQAur;NjxF5DN|${Y@_{
zY2tX%{k_OTjQy|9@1B*#ua)c!9UL+0ZgvfW43LqvH8V`p?LqM
z`_I@fT2aKsC0behKKoytKZpPKjNdo&oA4jE|G%#hCdPkjgPoIu)h|smF-BmlFxD6w
z?1e7y{LuvyV_~cX)(Z0m
ztBaa13W6)#hcA?5uGdD)XGG3_@Lw#gpAQL!FO=oY2S+UXCNDLFEJlXFg#?(`;hb>r
zZ2p%s__M4i;vW+IUz~v5Kj3x&h+n9`fCBgfxP87q+Xs8Xl1^B2TL;wKhk&^c`AcE3
z3+buzuaf8A#Vp5`h+feDPYMu^zvTWux)RV}{Cy|?63Eri?jP6$$%KDl%GN>M*47Fo
zWoYAU2xonOu`$6oU>w-sAZEn=ffM)zHpKpc2ly`xi2c%gQN({`_up>e_pU%w1sW~H
z-=+=h{B0b;)x}sGK-2e5zOw){n<^Mh1yf~Z;11Y^0=PKTz-6$71Ab9BH~((S;;;gj
zevN|`8bbc3eL?m6V*oxbdcU47m!}p1fbK{^T2kFL^=FzJmZH7l+;*2Yg+Dt+c2FKY
zW{%HM5+(yNxxvlTf|Eiurj@pRxr%6fus#5zM
zX3X-{7YE1m#T7WC-=lt1@(qi4w!?)_#6n>Z
zVmv7iyYnA}+*A=hBuXd5!Lo)kX59RQ%#;%LB5l?%cIdx@Eqf%&aC9$0cZ(dEu1#x
zsNmKpgRcaXeUeB;Xp-zFC_o;a%{UW4q`dUmJF$%Tk(Ga{k%TBPBtFY5mp#LROL4<}
z?4uY|Gfl6>WJmv19su07LGB4ISp+$eNw^6>t9|sWZ$zjp74lx-YH&;6e1Rxt#-;KHi=)!r3cbhyKLW*7&q;VK<+bHqh@s4;o
zOMEa6ES35U`SBnTO^t$$>q*|pzKtK)`cn@GZ~|pTHX!&}(CT|91Q5IN9cG}Q|BJ*o
z75^;m#l8<(1LhwQ*Zt;*6Mb$1$$GU=iJ1Nl7_n1;k4b40fyCFS-dn+gtjf>t;oP?x$keZK@i)IC@9xa5Bfy(oKs(Ol
z!bx#A6@#SZc|Vr!ehh!NEwoX^LcBrX?~5;b8vVrC;+Tgv5ay8c5`uI
z8KI4gAbs0{e?>#-xe`QRP?eg5MYrq$^9tr=ECVopU%*T$+_}|9pdbHhB6Kv!?u?c7
zV0HsV?@JJPLPuG{@Y&kEZnFK>u}@DA=TOys6Mkyz_j6lYLIMKA_u@H}RNgm~0w&$2>cSsHPShmRYMeu}Kuk^xkz+<{Q~vQeG;(fUtn9ES7ftExt6
z+Idc=zH@Kt^|!pmXwV!{k@Wh$KvKJi_7Z4%4_A7
zN*k|!8`S_+y>sB
zzVu7O!}ls}PTRKT$hE%ktbAc=L124}}?xCMt4>XrJ<%jX(%Z
z2r@E!%-_|pdj0{OVRxvG>H1#_s;H+w+LHo3phOI`+`f
z(Q$C-pgjo8^KH{iye
zd0Lt+%4o$ecC+bzu)FmWXW&avb4+@AeR?{oa_n}OqC+VYA7AryN2N9=O?xXlEWGvG
zH{s@KsDP(z_RCE8lnHN}2Z@;2X+VI=QH29dLPiFAm!hgFqW^x;M@mY{NLxy=J1}T@
zrOTUqLLfh^tgNiGtgPtRvOZ=Y?BLB9)sLXGzK?d|r`XGmkvxGA*oMC_SFa!;`v3lUFr5=^5uIhh=6@nx;E8HlV}4L(A?z-efFet6wp8sd{g@V&?r7
z(wGAzZA&zjm~Yqh)4IIfb@viK>^u)d8v>uF#F4x+Ot=&>Zu=7xJ_pS!r5r^58y_z{
z$2Dm7jlCFvDkeSa4-UwMu8rfIFfDl>bF6M^27R
z%H2Z!vV|{!xa5FAgfWlBNKu__-7XhB6>ycDyy5)Rwn+C0J5-duWo7qt{T$`7V=h*x
zXJ~l#LKQXe?c3qbP~X8RPDBYjRQzb*cIA|}<%#n^tdo^83?hod>ob~50z$!NznH`
zO=vhN99hxOh=aA*vtE7XURMu(%FI$!Q_WpZHsz$%KzwlNtu|_Ev@P&qv~E#R;rf>I
zr!Xp)9oP?M>gJM{mGVZ&O0iPB4(TdI6X6BCTw6=mR2p1v2@Y=aIwqwR)24eS4W-8m
z7_g}v`XZA(us=1mzgmV!7oWNwt^j?m#{inK2x1qP{Zvu6h1JpckXuJ#J@Ni9-lPbZ
z@sfp&X=4_Z)E25snfm8PIkmN^gM$ocgOMUJkNsO4uE9KQre*WTQw_J-*#j2c!du6`
zepP}p2i9)3MaqOPPx|=iKYYkXYtNxuo%isa)RUWv6c<8Opt9Ki-r7eRYCUmefc95_1qwMT
z(z!2nUAojuMClGyi=q`doCz-$Jv-Qk_x-3DA*SsixK3ZXEDWlfi+pAsn
z`t{4{h$%Y2Hwq;@JnqRQ<|X54Ff}6BKR$k^ix9XXGf`Ms)X}ZfP>js5o!bw|A->jUKSv7g7r
zM#jd-fJc^tl=XIfX5z;;B&kj#L0`vSTj00&@H
zKsn`a&L>O}hcuT`#m*qszA=K_o@8b(a-cwNM?dnpa5MXN6*3kg~DXF5xjB~MJ
zZu;mnG0W`*5lSwOk!37+r9Nld`XFxGHFu9M+16J%Zq8h;oOmA=8TZuk@%u^_vC|#L
zMuTYB7Pw8yYuEDf9Tt@M06Urn*UhFrueCA-h3nq}cUMMirs~UHzPv)e;VSFAKB=!;
zr%oXQ+^480DXA(f3`lMmA78!39zTfJD-R+7H&9|l?mHcx!^D8wQ$A#c<-p<5b0_cY
zbu(}As1NJRkEsWC$7Ru>S*(SH*4Eavo16e6%7XHl^y!(zgV>GRw>;Dp)wPH^oI3~Ah1i#vgv4t@aw3+A=uNte
z+IT7EpwC38&2Q;W8_Qdsy?r9i$1oOyM@4K5H${-LHp`rDNKn;`x)cjG<
z&(6|`Yu}H?2Xu8OD7kiEXxI$s(6Pdxo+rH?EV$D~B|_syLSlv_6-|#RDhe+%OAw~1
zdiLyDqXEo#H-HSZ1%9B)sH>01#>K^?T2+h|ySER^jb@Ru#-@m~U2WQwMDc2%ROA
z!yIng*Lms+$?M_sXUBrYS`VeRDcWt2J}ZS9wsjR{IXTzp2dBh8cZ&D)2*a0vEg3^%O1&TQ4IcdlNw@fyKZG=MG(^5
zXkfQKSv^{Ium1FBF%x5=4%dE1XaWo!>VD#SI?~{Momq)LQP1*s
zm4HB(zK1n_yZyA?tD?1LFy0s%d_alNdxu?aN>0w7nL*zn*4|g(ln4!9PriFY$h&yP
zO@jGh)7E@Ao!eZBj7;PkWm%|(dUqOMuGIsY#&xo49UBdei_596cDGOrb2=HsOJ7t}
z414>ipq-BxP|59D*=_#HDB5lhP7a|v({^#Ob#YnO^rU!HsC^cM@O+dz^t>2oV-y#+
zov0oL^O{&jFh-cD*`M;-H!Ce&4I!m<64u=*U}I;0`{qsZsKO^c{S);;vBH-1;DHgF
zsAdy957G@$Ae!pYgoCu-fF{KZPe$I>qtT>kRQkwxiV$o7p9h}r2>^kiioL9Bx-WZ+
znKhRBbEeMtNT7fc^Q^9mwPu4ir*bWQpfrP8uH({3!{px4iU>nYD5Ju9-G1YQ%ha;n
zoD74SQ@6WJ_5*%lA&BZO(a!FxFf12fJwQy#W`Ai^&a
zw=accMOqKDNA&x~8*{q;0MXNz&P1WYfW)0?h-OH83mMJyt|rVWHcv!#r(xT{qlvO9j9=}03qpoW
zn#)Fs<9)XD*|#*@ILYy@B&FB2wd3HbKPT(NuRa67keCopil&{3nHeMg->VM&yA8l`@d8lvjV5oSHq3Rj7P85LeXFV*bkPYPLUrV2q%yG*(@l$0=ed!xTU7o045t5a@1P`gQ)
z6QA|r!@Fwt&em|TQjTNKGpUfsAG_6HlyCkfi$P9fIJThSrXU-i&Iy0Gw+R&Q8xB2W
zHh2FD75Pi-6%(Nu!kJa>0sP3r19Ys8q5Qdv%f!l=soVH>KLC)^R$q^%E+ry(tn@5R
z(X(82_o{gqP$HXv0{j*q0G_9uUrNEknvxji|OP=sd{=0YGM4QW%oomF{Iw9E;t`R&a
z=hf_~(={?biD*-rj#4xf5`Js4khkQo4e4`Q1L!z+}LBN;v)`^{oNz3B4
zJq^62Sj%zctyU8MwIglYF69;_8}J1a&RdCnAv-c9m0o;Er*j8a5;v)YFZ1cZPtlnz
ztOA)9<=n~(7O^@_rryMc$6KmBmuek;LZ8VRPunq45Q6t)z%9x@Cu~16_r$i2PgE4$
z#mJOIel-HPuesFG*OzlMCFP0Fc8WD8=i<_(Fqj;4XK&AX>daYHRbRA}1kfvX=}W&z
zIs+s#k}UKaynliTPl^4kdz~IGKz%Da2|d!t$TT#)Uy&jXnpV&X*_1I2aUSo{j@NtT
zPSItK*2gmX987PH=y`N4ZW|M4vG!*8+hx`^=HxEalEx>34WPZZ#JqCCx+as-E{o+
z?FtQHke>T;oNf&&lPd-W#S8HAIt633QjSXiI>
z^l8<+k8h)%d0ej9?J`j@-SP}Csy%?`W>t}`d@#z^E$}R&rU{INxlQSwSTT4qzsje~aX4B{&cDUUo
z6q#|gvVrEaiZ1?U^ql(JFR$?FP9Es%pC|G4vBuj-e+#nI+2~G_WbuX*
zk13D5YIg$KBZ!2%`w!5O*-U!o_L$BhtG)F=SBJ{TD5?`oQUU|H2bG90Bg}jgn7NVQ
zWoNP8>a%+|R0W338K?;g!XOznRyH>GgXv-+nkr~3?SXmxz=2=0X3#P_(0@?(*loqU
z#FhPZ%}pqMxq)QsY?P1B;T98KK+J432Ya~%cu>Bs=`P5vhTXzq&&}}t!BG*dN0D^r
zORbdWOU`7Md>b1b!2k2wWyyD%-Jlb{b7$d^1}EDchnA7l
zfkgpv@qIA8UJJX{weo|MPTae+*p{1$cnKQ})@#8~3@pn&F3kVb)AKsy^Jn$N`T5le
zkBw%ssJnOHf6!Ue(sqR{R@EXW+6HenQatvtH)LfERlOC3qlNuoZCKpQ}Ryx1z24z&ht5X1gT7^|3SfVoloqkwqNUQ&s%-mh41q7nO_V%
zt$3ZAo2I*l3Hzi@Es_maLc-^iUUyfrl@230Q<7#$f|
z1lj?~hYyJW3QEfQjix@ud*K2MXGyDvC4`*q@YS0cVR(Qq%w$eVg$-Z;xZg^jQLkaA
zJIeKuIH6Z_P~!~<>_DzGOYsT|98#$%uS%};SWsKmAFkxi*T+vuyOONpML1ixHfO@_
zRmbwOF}Q@Sc~fp6CP=Tiue;pL^xN#^whZeNp9tC8%1r6yf@eIpo6}=ryq}#WSe5p~
zqvaSLi$<(v6$F4Kv?ce&4{pg=sU3GO&m-e++l)+F9M`u$|D<>-MMf*^eCElHmWH+b
z_>n^@H#z6nqhtE^fB^3tiWP{_#D;kzXISM3uETMC18&FE)47q6=F$6a?o6`=hGxOE
z(iSjYP)xYO%2mRi+B53d<08i76;>D>wbBG?azr=(tA^eY@0FS^$K85tR?{2Z8k3=-
zABV~^2
z6e(mdO0Rv=xIAeHGu~anNyHwJ5JWPyabSNgi=DF;DLl?FS3aBe9K2RkY>V`&CIFy*
zASCDK9cPR0eunTpIw{4aOb+#-QvQ)g4!IE&olX*Z$v;zdi;i$#)`&-`h;!&N&Q+%-
z?T3?J=2{6(uKApX*cca
zA8uB;e#9yLMqd-^&2US=LD8N~`r3ysPRm9K?zYfb+%%QqtWl#+i=36?)@(f#-F$8p
z`w@8b{65}?oR!_L{b<0{`d;JN@a|DAuT-n*cO&wzWh}%8S72%(3tX4(5mw-v-zOm<
zdg>i^O33cOeI+|TE);Y0xwWnPiEcBpo4N7-&-c{-?ThlQbDSH~N@aH*^zwlJ5dbL2s7M!~4gCKH;L*^k
literal 0
HcmV?d00001
--
GitLab
From c624cce88e2ccf48a410056192992366499c757d Mon Sep 17 00:00:00 2001
From: IanShaw027
Date: Tue, 21 Apr 2026 15:56:30 +0800
Subject: [PATCH 151/261] fix: unblock auth identity compat backfill migration
---
...entity_compat_backfill_integration_test.go | 73 +++++++++++++++++++
.../internal/repository/migrations_runner.go | 6 ++
.../migrations_runner_checksum_test.go | 9 +++
.../109_auth_identity_compat_backfill.sql | 3 +
4 files changed, 91 insertions(+)
create mode 100644 backend/internal/repository/auth_identity_compat_backfill_integration_test.go
diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
new file mode 100644
index 00000000..56b37512
--- /dev/null
+++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
@@ -0,0 +1,73 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql")
+ migration108SQL, err := os.ReadFile(migration108Path)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE;
+DROP TABLE IF EXISTS auth_identity_channels CASCADE;
+DROP TABLE IF EXISTS identity_adoption_decisions CASCADE;
+DROP TABLE IF EXISTS pending_auth_sessions CASCADE;
+DROP TABLE IF EXISTS auth_identities CASCADE;
+
+ALTER TABLE users
+ DROP COLUMN IF EXISTS signup_source,
+ DROP COLUMN IF EXISTS last_login_at,
+ DROP COLUMN IF EXISTS last_active_at;
+`)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108SQL))
+ require.NoError(t, err)
+
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&userID))
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ var reportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(userID, 10)).Scan(&reportCount))
+ require.Equal(t, 1, reportCount)
+
+ var reportTypeLimit int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeLimit))
+ require.GreaterOrEqual(t, reportTypeLimit, 45)
+
+ require.NotZero(t, userID)
+}
diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go
index 9cf3b392..5a2e6677 100644
--- a/backend/internal/repository/migrations_runner.go
+++ b/backend/internal/repository/migrations_runner.go
@@ -73,6 +73,12 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
+ "109_auth_identity_compat_backfill.sql": {
+ fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ acceptedDBChecksum: map[string]struct{}{
+ "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {},
+ },
+ },
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go
index 6c3ad725..6030991b 100644
--- a/backend/internal/repository/migrations_runner_checksum_test.go
+++ b/backend/internal/repository/migrations_runner_checksum_test.go
@@ -51,4 +51,13 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
)
require.False(t, ok)
})
+
+ t.Run("109历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ )
+ require.True(t, ok)
+ })
}
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
index ddbbedbc..5147ae45 100644
--- a/backend/migrations/109_auth_identity_compat_backfill.sql
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -1,3 +1,6 @@
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(80);
+
INSERT INTO auth_identities (
user_id,
provider_type,
--
GitLab
From 59290e39f99ad61781c21b2f197c26e7b81278eb Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 21 Apr 2026 17:18:37 +0800
Subject: [PATCH 152/261] chore(channels): drop admin-side available channels
view
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Remove the admin-side "Available Channels" aggregate view — admins
already see full channel configuration (groups, pricing, model
mappings) in the channel edit dialog, making a read-only admin
aggregate view redundant. The user-side "可用渠道" remains.
Backend:
- Delete handler/admin/available_channel_handler.go (+ test)
- Drop AdminHandlers.AvailableChannel field and wire injection
- Remove /admin/channels/available route
Frontend:
- Delete views/admin/AvailableChannelsView.vue
- Drop /admin/available-channels router entry
- Strip AvailableChannel types + listAvailable from api/admin/channels.ts
---
backend/cmd/server/wire_gen.go | 3 +-
.../admin/available_channel_handler.go | 95 ----------
.../admin/available_channel_handler_test.go | 57 ------
backend/internal/handler/handler.go | 1 -
backend/internal/handler/wire.go | 3 -
backend/internal/server/routes/admin.go | 1 -
frontend/src/api/admin/channels.ts | 39 +----
frontend/src/router/index.ts | 12 --
.../src/views/admin/AvailableChannelsView.vue | 164 ------------------
9 files changed, 2 insertions(+), 373 deletions(-)
delete mode 100644 backend/internal/handler/admin/available_channel_handler.go
delete mode 100644 backend/internal/handler/admin/available_channel_handler_test.go
delete mode 100644 frontend/src/views/admin/AvailableChannelsView.vue
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 7568fa50..9028210c 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -175,7 +175,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator)
- availableChannelHandler := admin.NewAvailableChannelHandler(channelService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
@@ -236,7 +235,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, availableChannelHandler, paymentHandler)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
diff --git a/backend/internal/handler/admin/available_channel_handler.go b/backend/internal/handler/admin/available_channel_handler.go
deleted file mode 100644
index 45b8f357..00000000
--- a/backend/internal/handler/admin/available_channel_handler.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package admin
-
-import (
- "github.com/Wei-Shaw/sub2api/internal/pkg/response"
- "github.com/Wei-Shaw/sub2api/internal/service"
-
- "github.com/gin-gonic/gin"
-)
-
-// AvailableChannelHandler 处理「可用渠道」聚合视图的管理员接口。
-//
-// 该视图以只读方式聚合渠道基础信息、关联分组与推导出的支持模型列表(无通配符)。
-type AvailableChannelHandler struct {
- channelService *service.ChannelService
-}
-
-// NewAvailableChannelHandler 创建 AvailableChannelHandler 实例。
-func NewAvailableChannelHandler(channelService *service.ChannelService) *AvailableChannelHandler {
- return &AvailableChannelHandler{channelService: channelService}
-}
-
-// availableGroupResponse 响应中的分组概要。
-type availableGroupResponse struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- Platform string `json:"platform"`
-}
-
-// supportedModelResponse 响应中的支持模型条目。
-type supportedModelResponse struct {
- Name string `json:"name"`
- Platform string `json:"platform"`
- Pricing *channelModelPricingResponse `json:"pricing"`
-}
-
-// availableChannelResponse 管理员视图完整字段集。
-type availableChannelResponse struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- Description string `json:"description"`
- Status string `json:"status"`
- BillingModelSource string `json:"billing_model_source"`
- RestrictModels bool `json:"restrict_models"`
- Groups []availableGroupResponse `json:"groups"`
- SupportedModels []supportedModelResponse `json:"supported_models"`
-}
-
-// availableChannelToAdminResponse 将 service 层的 AvailableChannel 转为管理员 DTO。
-// 同 package 内复用;也用于构造测试 fixture。
-func availableChannelToAdminResponse(ch service.AvailableChannel) availableChannelResponse {
- groups := make([]availableGroupResponse, 0, len(ch.Groups))
- for _, g := range ch.Groups {
- groups = append(groups, availableGroupResponse{ID: g.ID, Name: g.Name, Platform: g.Platform})
- }
- models := make([]supportedModelResponse, 0, len(ch.SupportedModels))
- for i := range ch.SupportedModels {
- m := ch.SupportedModels[i]
- var pricing *channelModelPricingResponse
- if m.Pricing != nil {
- p := pricingToResponse(m.Pricing)
- pricing = &p
- }
- models = append(models, supportedModelResponse{
- Name: m.Name,
- Platform: m.Platform,
- Pricing: pricing,
- })
- }
- return availableChannelResponse{
- ID: ch.ID,
- Name: ch.Name,
- Description: ch.Description,
- Status: ch.Status,
- BillingModelSource: ch.BillingModelSource,
- RestrictModels: ch.RestrictModels,
- Groups: groups,
- SupportedModels: models,
- }
-}
-
-// List 列出所有可用渠道(管理员视图)。
-// GET /api/v1/admin/channels/available
-func (h *AvailableChannelHandler) List(c *gin.Context) {
- channels, err := h.channelService.ListAvailable(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- out := make([]availableChannelResponse, 0, len(channels))
- for _, ch := range channels {
- out = append(out, availableChannelToAdminResponse(ch))
- }
- response.Success(c, gin.H{"items": out})
-}
diff --git a/backend/internal/handler/admin/available_channel_handler_test.go b/backend/internal/handler/admin/available_channel_handler_test.go
deleted file mode 100644
index 7d249383..00000000
--- a/backend/internal/handler/admin/available_channel_handler_test.go
+++ /dev/null
@@ -1,57 +0,0 @@
-//go:build unit
-
-package admin
-
-import (
- "encoding/json"
- "testing"
-
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/stretchr/testify/require"
-)
-
-func TestAvailableChannelToAdminResponse_IncludesFullDTO(t *testing.T) {
- // 管理员视图应包含 id / status / billing_model_source / restrict_models 等
- // 管理字段;mapper 是纯透传,BillingModelSource 的默认回填由 service 层负责。
- input := service.AvailableChannel{
- ID: 42,
- Name: "ch",
- Description: "d",
- Status: service.StatusActive,
- BillingModelSource: service.BillingModelSourceChannelMapped,
- RestrictModels: true,
- Groups: []service.AvailableGroupRef{
- {ID: 1, Name: "g1", Platform: "anthropic"},
- },
- SupportedModels: []service.SupportedModel{
- {Name: "claude-sonnet-4-6", Platform: "anthropic"},
- },
- }
-
- resp := availableChannelToAdminResponse(input)
- require.Equal(t, int64(42), resp.ID)
- require.Equal(t, "ch", resp.Name)
- require.Equal(t, service.StatusActive, resp.Status)
- require.Equal(t, service.BillingModelSourceChannelMapped, resp.BillingModelSource)
- require.True(t, resp.RestrictModels)
- require.Len(t, resp.Groups, 1)
- require.Len(t, resp.SupportedModels, 1)
-
- // JSON 层验证管理字段确实会被序列化。
- raw, err := json.Marshal(resp)
- require.NoError(t, err)
- var decoded map[string]any
- require.NoError(t, json.Unmarshal(raw, &decoded))
- for _, key := range []string{"id", "status", "billing_model_source", "restrict_models", "groups", "supported_models"} {
- _, exists := decoded[key]
- require.Truef(t, exists, "admin DTO must expose %q", key)
- }
-}
-
-func TestAvailableChannelToAdminResponse_PreservesExplicitBillingSource(t *testing.T) {
- input := service.AvailableChannel{
- BillingModelSource: service.BillingModelSourceUpstream,
- }
- resp := availableChannelToAdminResponse(input)
- require.Equal(t, service.BillingModelSourceUpstream, resp.BillingModelSource)
-}
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index a35d8041..aee9d927 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -33,7 +33,6 @@ type AdminHandlers struct {
Channel *admin.ChannelHandler
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
- AvailableChannel *admin.AvailableChannelHandler
Payment *admin.PaymentHandler
}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index c9296b44..6d175488 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -36,7 +36,6 @@ func ProvideAdminHandlers(
channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
- availableChannelHandler *admin.AvailableChannelHandler,
paymentHandler *admin.PaymentHandler,
) *AdminHandlers {
return &AdminHandlers{
@@ -67,7 +66,6 @@ func ProvideAdminHandlers(
Channel: channelHandler,
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
- AvailableChannel: availableChannelHandler,
Payment: paymentHandler,
}
}
@@ -170,7 +168,6 @@ var ProviderSet = wire.NewSet(
admin.NewChannelHandler,
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
- admin.NewAvailableChannelHandler,
admin.NewPaymentHandler,
// AdminHandlers and Handlers constructors
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index e4b5c548..4b796d55 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -560,7 +560,6 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels := admin.Group("/channels")
{
channels.GET("", h.Admin.Channel.List)
- channels.GET("/available", h.Admin.AvailableChannel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index 7ad4af28..9d430134 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -164,42 +164,5 @@ export async function getModelDefaultPricing(model: string): Promise {
- const { data } = await apiClient.get('/admin/channels/available', {
- signal: options?.signal
- })
- return data.items
-}
-
-const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, listAvailable }
+const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
export default channelsAPI
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index 567876b6..dc886b23 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -370,18 +370,6 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.groups.description'
}
},
- {
- path: '/admin/available-channels',
- name: 'AdminAvailableChannels',
- component: () => import('@/views/admin/AvailableChannelsView.vue'),
- meta: {
- requiresAuth: true,
- requiresAdmin: true,
- title: 'Available Channels',
- titleKey: 'admin.availableChannels.title',
- descriptionKey: 'admin.availableChannels.description'
- }
- },
{
path: '/admin/channels',
redirect: '/admin/channels/pricing'
diff --git a/frontend/src/views/admin/AvailableChannelsView.vue b/frontend/src/views/admin/AvailableChannelsView.vue
deleted file mode 100644
index a9b2462f..00000000
--- a/frontend/src/views/admin/AvailableChannelsView.vue
+++ /dev/null
@@ -1,164 +0,0 @@
-
-